@@ -73,14 +73,20 @@ class Jais2ModelTest(CausalLMModelTest, unittest.TestCase):
     )


-JAIS2_8B_CHECKPOINT = "inceptionai/jais-2-8b"
+JAIS2_8B_CHECKPOINT = "inceptionai/Jais-2-8B-Chat"


 @require_torch
 class Jais2IntegrationTest(unittest.TestCase):
-    # Update this path to your local checkpoint
     checkpoint = JAIS2_8B_CHECKPOINT

+    def setUp(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
+        if self.tokenizer.chat_template is None:
+            self.tokenizer.chat_template = (
+                "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
+            )
+
     def tearDown(self):
         backend_empty_cache(torch_device)
         gc.collect()
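
A minimal sketch (not part of the diff) of what the setUp fallback template produces once applied; the checkpoint name and the missing-template assumption are taken from the hunk above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("inceptionai/Jais-2-8B-Chat")
if tokenizer.chat_template is None:
    # Same fallback as in setUp: "role: content" lines joined by newlines
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
    )

messages = [{"role": "user", "content": "The capital of France is"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
# prompt == "user: The capital of France is\n"
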
@@ -94,10 +100,9 @@ def test_model_logits(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         input_text = "The capital of France is"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(model.device)

         with torch.no_grad():
             outputs = model(input_ids)
@@ -129,10 +134,9 @@ def test_model_logits_bf16(self):
             device_map="auto",
             torch_dtype=torch.bfloat16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         input_text = "The capital of France is"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(model.device)

         with torch.no_grad():
             outputs = model(input_ids)
@@ -160,10 +164,9 @@ def test_model_generation(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         prompt = "The capital of France is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)

         # Greedy generation
         generated_ids = model.generate(
@@ -172,7 +175,7 @@ def test_model_generation(self):
             do_sample=False,
         )

-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Generated text: {generated_text}")

         # Check that generation produced new tokens
@@ -195,18 +198,17 @@ def test_model_generation_sdpa(self):
             torch_dtype=torch.float16,
             attn_implementation="sdpa",
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         prompt = "Artificial intelligence is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)

         generated_ids = model.generate(
             input_ids,
             max_new_tokens=20,
             do_sample=False,
         )

-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"SDPA Generated text: {generated_text}")

         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
@@ -228,18 +230,17 @@ def test_model_generation_flash_attn(self):
             torch_dtype=torch.float16,
             attn_implementation="flash_attention_2",
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         prompt = "Machine learning models are"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)

         generated_ids = model.generate(
             input_ids,
             max_new_tokens=20,
             do_sample=False,
         )

-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Flash Attention Generated text: {generated_text}")

         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
@@ -279,9 +280,8 @@ def test_layer_norm(self):
     @require_torch_accelerator
     def test_attention_implementations_consistency(self):
         """Test that different attention implementations produce similar outputs."""
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
         prompt = "Hello, how are you?"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt")
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")

         # Test with eager attention
         model_eager = Jais2ForCausalLM.from_pretrained(
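
For context, a minimal sketch of how an eager-vs-SDPA consistency check is typically written; this is an assumption for illustration, since the test's actual comparison and assertion fall outside the hunk shown above:

import torch
from transformers import AutoTokenizer, Jais2ForCausalLM  # Jais2ForCausalLM is the class added by this PR

checkpoint = "inceptionai/Jais-2-8B-Chat"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
input_ids = tokenizer.encode("Hello, how are you?", return_tensors="pt")

logits = {}
for impl in ("eager", "sdpa"):
    model = Jais2ForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float16,
        attn_implementation=impl,
    )
    with torch.no_grad():
        logits[impl] = model(input_ids.to(model.device)).logits.float().cpu()

# fp16 attention kernels differ slightly, so compare with a loose tolerance
assert torch.allclose(logits["eager"], logits["sdpa"], atol=1e-2, rtol=1e-2)
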
@@ -328,10 +328,9 @@ def test_compile_static_cache(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         prompt = "The future of AI is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)

         # Generate with static cache
         generated_ids = model.generate(
@@ -341,7 +340,7 @@ def test_compile_static_cache(self):
             cache_implementation="static",
         )

-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Static cache generated text: {generated_text}")

         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
@@ -360,10 +359,9 @@ def test_export_static_cache(self):
             device_map="auto",
             torch_dtype=torch.float16,
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

         prompt = "Deep learning is"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)

         # First verify regular generation works
         generated_ids = model.generate(
@@ -372,7 +370,7 @@ def test_export_static_cache(self):
             do_sample=False,
         )

-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         print(f"Export test generated text: {generated_text}")

         self.assertGreater(generated_ids.shape[1], input_ids.shape[1])