
Commit 9e0839b

updates tests
1 parent 672e38a commit 9e0839b

File tree

1 file changed: +21 -23 lines changed

tests/models/jais2/test_modeling_jais2.py

Lines changed: 21 additions & 23 deletions
@@ -73,14 +73,20 @@ class Jais2ModelTest(CausalLMModelTest, unittest.TestCase):
     )
 
 
-JAIS2_8B_CHECKPOINT = "inceptionai/jais-2-8b"
+JAIS2_8B_CHECKPOINT = "inceptionai/Jais-2-8B-Chat"
 
 
 @require_torch
 class Jais2IntegrationTest(unittest.TestCase):
-    # Update this path to your local checkpoint
     checkpoint = JAIS2_8B_CHECKPOINT
 
+    def setUp(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
+        if self.tokenizer.chat_template is None:
+            self.tokenizer.chat_template = (
+                "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
+            )
+
     def tearDown(self):
         backend_empty_cache(torch_device)
         gc.collect()
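
For reference, a minimal sketch (not part of this commit) of what the fallback chat template added in setUp() renders for a short conversation. It uses jinja2 directly so it runs without downloading the checkpoint's tokenizer; in the test the same template would be rendered through tokenizer.apply_chat_template.

# Sketch only: shows the output of the fallback template added in setUp().
from jinja2 import Template

fallback = Template(
    "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
)
messages = [
    {"role": "user", "content": "The capital of France is"},
    {"role": "assistant", "content": "Paris."},
]
print(fallback.render(messages=messages))
# user: The capital of France is
# assistant: Paris.
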
@@ -94,10 +100,9 @@ def test_model_logits(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         input_text = "The capital of France is"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(model.device)
 
         with torch.no_grad():
             outputs = model(input_ids)
@@ -129,10 +134,9 @@ def test_model_logits_bf16(self):
             device_map="auto",
             torch_dtype=torch.bfloat16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         input_text = "The capital of France is"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(model.device)
 
         with torch.no_grad():
             outputs = model(input_ids)
@@ -160,10 +164,9 @@ def test_model_generation(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         prompt = "The capital of France is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
         # Greedy generation
         generated_ids = model.generate(
@@ -172,7 +175,7 @@ def test_model_generation(self):
             do_sample=False,
         )
 
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Generated text: {generated_text}")
 
         # Check that generation produced new tokens
@@ -195,18 +198,17 @@ def test_model_generation_sdpa(self):
             torch_dtype=torch.float16,
             attn_implementation="sdpa",
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         prompt = "Artificial intelligence is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
         generated_ids = model.generate(
             input_ids,
             max_new_tokens=20,
             do_sample=False,
         )
 
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"SDPA Generated text: {generated_text}")
 
         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
@@ -228,18 +230,17 @@ def test_model_generation_flash_attn(self):
             torch_dtype=torch.float16,
             attn_implementation="flash_attention_2",
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         prompt = "Machine learning models are"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
         generated_ids = model.generate(
             input_ids,
             max_new_tokens=20,
             do_sample=False,
         )
 
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Flash Attention Generated text: {generated_text}")
 
         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
@@ -279,9 +280,8 @@ def test_layer_norm(self):
     @require_torch_accelerator
     def test_attention_implementations_consistency(self):
         """Test that different attention implementations produce similar outputs."""
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
         prompt = "Hello, how are you?"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt")
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
 
         # Test with eager attention
         model_eager = Jais2ForCausalLM.from_pretrained(
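
The assertion itself falls outside this hunk, so the following is only a hedged sketch of the eager-vs-SDPA comparison the test presumably performs; the tolerance values and the use of AutoModelForCausalLM here are assumptions, not taken from the diff.

# Assumed sketch: compare logits from eager and SDPA attention on the same input.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "inceptionai/Jais-2-8B-Chat"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
input_ids = tokenizer.encode("Hello, how are you?", return_tensors="pt")

logits = {}
for impl in ("eager", "sdpa"):
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        attn_implementation=impl,
        device_map="auto",
    )
    with torch.no_grad():
        logits[impl] = model(input_ids.to(model.device)).logits.float().cpu()

# fp16 attention kernels differ slightly, so compare with a loose tolerance
assert torch.allclose(logits["eager"], logits["sdpa"], atol=1e-2, rtol=1e-2)
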
@@ -328,10 +328,9 @@ def test_compile_static_cache(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         prompt = "The future of AI is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
         # Generate with static cache
         generated_ids = model.generate(
@@ -341,7 +340,7 @@ def test_compile_static_cache(self):
             cache_implementation="static",
         )
 
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Static cache generated text: {generated_text}")
 
         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
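
The test name suggests torch.compile is exercised together with the static cache, but the compile call is not visible in these hunks. The snippet below is an assumed sketch of the usual pairing, not the test's actual implementation; the checkpoint name and model class are carried over from the diff.

# Assumed sketch: torch.compile plus a static KV cache for fixed-shape decoding.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "inceptionai/Jais-2-8B-Chat"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, device_map="auto"
)

# compile the forward pass; the static cache keeps shapes fixed so the graph is reused
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_ids = tokenizer.encode("The future of AI is", return_tensors="pt").to(model.device)
generated_ids = model.generate(
    input_ids,
    max_new_tokens=20,
    do_sample=False,
    cache_implementation="static",
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
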
@@ -360,10 +359,9 @@ def test_export_static_cache(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
 
         prompt = "Deep learning is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
         # First verify regular generation works
         generated_ids = model.generate(
@@ -372,7 +370,7 @@ def test_export_static_cache(self):
             do_sample=False,
         )
 
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Export test generated text: {generated_text}")
 
         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
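
One possible way to run the updated integration tests locally (assumed; not part of the commit). The test class name comes from the diff above; a GPU and access to the checkpoint are required, and the Transformers RUN_SLOW=1 switch may be needed depending on decorators not shown in these hunks.

# Hypothetical invocation of the updated tests (sketch only).
import os
import subprocess

env = dict(os.environ, RUN_SLOW="1")  # assumption: integration tests may be gated behind @slow
subprocess.run(
    [
        "python", "-m", "pytest", "-v",
        "tests/models/jais2/test_modeling_jais2.py::Jais2IntegrationTest",
    ],
    env=env,
    check=True,
)
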
