|
30 | 30 | # Tests are skipped if model paths are not configured. |
31 | 31 | # END-GOOGLE-INTERNAL |
32 | 32 |  |
| 33 | +import os |
33 | 34 | import unittest |
34 | 35 |  |
| 36 | +from absl.testing import absltest |
35 | 37 | from absl.testing import parameterized |
36 | | -from flax import nnx |
37 | | -import jax |
38 | | -from tunix.models.gemma3 import params as gemma3_params_lib |
39 | | -import numpy as np |
40 | 38 | from flax.traverse_util import flatten_dict |
41 | | - |
| 39 | +import jax.numpy as jnp |
| 40 | +import numpy as np |
| 41 | +import safetensors.numpy as safe_np |
| 42 | +from tunix.models.gemma3 import model as gemma3_model |
| 43 | +from tunix.models.gemma3 import params as gemma3_params |
| 44 | +from tunix.models.gemma3 import params_safetensors as gemma3_params_safetensors |
| 45 | +from tunix.tests import lora_params_test_base |
| 46 | +from tunix.tests import test_common |
42 | 47 |  |
43 | 48 | class GemmaParamsTest(parameterized.TestCase): |
44 | 49 |  |
@@ -90,7 +95,7 @@ def test_map_from_upstream_checkpoint(self, model_type): |
90 | 95 | "transformer/layer_0/pre_ffw_norm": {"scale": pre_ffw}, |
91 | 96 | } |
92 | 97 |  |
93 | | - mapped = gemma3_params_lib.map_from_upstream_checkpoint(upstream, model_type) |
| 98 | + mapped = gemma3_params.map_from_upstream_checkpoint(upstream, model_type) |
94 | 99 | flat_m = flatten_dict(mapped) # tuple keys |
95 | 100 |  |
96 | 101 | # --- Keys & shapes we expect after mapping (tiny) --- |
@@ -174,14 +179,268 @@ def test_map_from_upstream_checkpoint(self, model_type): |
174 | 179 | assert not any(k[0] == 'siglip_encoder' for k in flat_m.keys()) |
175 | 180 | assert ('embedder', 'mm_patch') not in mapped.get('embedder', {}) |
176 | 181 |  |
| 182 | + |
| 183 | +class Gemma3LoraParamsTest(lora_params_test_base.LoraParamsTestBase): |
| 184 | + """Tests for Gemma3 LoRA merged model saving and loading.""" |
| 185 | + |
| 186 | + def create_config(self): |
| 187 | + """Create Gemma3 model config for testing.""" |
| 188 | + return gemma3_model.ModelConfig( |
| 189 | + num_layers=2, |
| 190 | + num_embed=256, |
| 191 | + embed_dim=64, |
| 192 | + hidden_dim=128, |
| 193 | + num_heads=4, |
| 194 | + head_dim=16, |
| 195 | + num_kv_heads=1, |
| 196 | + sliding_window_size=128, # Required for LOCAL_SLIDING attention |
| 197 | + ) |
| 198 | + |
| 199 | + def get_model_class(self): |
| 200 | + """Get Gemma3 model class.""" |
| 201 | + return gemma3_model.Gemma3 |
| 202 | + |
| 203 | + def get_lora_module_path(self) -> str: |
| 204 | + """Get LoRA target modules for Gemma3.""" |
| 205 | + return '.*q_einsum|.*kv_einsum|.*attn_vec_einsum|.*gate_proj|.*up_proj|.*down_proj' |
| 206 | + |
| 207 | + def get_projection_keys(self, layer_idx: int) -> list[str]: |
| 208 | + """Get projection keys for Gemma3.""" |
| 209 | + prefix = f'model.layers.{layer_idx}' |
| 210 | + return [ |
| 211 | + f'{prefix}.self_attn.q_proj.weight', |
| 212 | + f'{prefix}.self_attn.k_proj.weight', |
| 213 | + f'{prefix}.self_attn.v_proj.weight', |
| 214 | + f'{prefix}.self_attn.o_proj.weight', |
| 215 | + f'{prefix}.mlp.gate_proj.weight', |
| 216 | + f'{prefix}.mlp.up_proj.weight', |
| 217 | + f'{prefix}.mlp.down_proj.weight', |
| 218 | + ] |
| 219 | + |
| 220 | + def save_merged_model(self, lora_model): |
| 221 | + """Save Gemma3 LoRA merged model.""" |
| 222 | + gemma3_params.save_lora_merged_model_as_safetensors( |
| 223 | + local_model_path=self.base_checkpoint_dir, |
| 224 | + output_dir=self.merged_output_dir, |
| 225 | + lora_model=lora_model, |
| 226 | + rank=self.rank, |
| 227 | + alpha=self.alpha, |
| 228 | + ) |
| 229 | + |
| 230 | + def create_model_from_checkpoint(self, checkpoint_dir: str): |
| 231 | + """Load Gemma3 model from checkpoint.""" |
| 232 | + return gemma3_params_safetensors.create_model_from_safe_tensors( |
| 233 | + file_dir=checkpoint_dir, |
| 234 | + config=self.config, |
| 235 | + mesh=None, |
| 236 | + dtype=jnp.float32, |
| 237 | + ) |
| 238 | + |
| 239 | + def _create_test_inputs(self): |
| 240 | + """Create test inputs for Gemma3 forward pass.""" |
| 241 | + batch_size = 2 |
| 242 | + seq_len = 10 |
| 243 | + |
| 244 | + input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) |
| 245 | + positions = jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0) |
| 246 | + # Gemma3 requires an attention mask - create a causal mask |
| 247 | + attention_mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len))) |
| 248 | + |
| 249 | + return input_tokens, positions, attention_mask |
| 250 | + |
| 251 | + def _run_forward_pass(self, model, input_tokens, positions, attention_mask): |
| 252 | + """Run forward pass through Gemma3 model.""" |
| 253 | + # Gemma3 uses the `last_tokens` parameter name |
| 254 | + return model( |
| 255 | + last_tokens=input_tokens, |
| 256 | + positions=positions, |
| 257 | + cache=None, |
| 258 | + attention_mask=attention_mask, |
| 259 | + ) |
| 260 | + |
| 261 | + def create_checkpoint(self, model) -> str: |
| 262 | + """Extract model weights and save in safetensors format. |
| 263 | + |
| 264 | + Uses the model's actual weights and applies inverse transformations |
| 265 | + to create a valid safetensors file compatible with the Gemma3 loader. |
| 266 | + |
| 267 | + Key difference from Qwen3: kv_einsum must be decomposed into k_proj and |
| 268 | + v_proj. |
| 269 | + |
| 270 | + Args: |
| 271 | + model: Base model to extract weights from. |
| 272 | + |
| 273 | + Returns: |
| 274 | + Path to the created checkpoint directory. |
| 275 | + """ |
| 276 | + os.makedirs(self.base_checkpoint_dir, exist_ok=True) |
| 277 | + |
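| | + # Key names follow the safetensors layout expected by |
| | + # create_model_from_safe_tensors (see get_projection_keys above). |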
| 278 | + base_state = {} |
| 279 | + |
| 280 | + # Embedder (no transformation needed) |
| 281 | + base_state['model.embed_tokens.weight'] = np.array( |
| 282 | + model.embedder.input_embedding.value |
| 283 | + ) |
| 284 | + |
| 285 | + # Final norm (no transformation needed) |
| 286 | + base_state['model.norm.weight'] = np.array(model.final_norm.scale.value) |
| 287 | + |
| 288 | + # Extract and transform weights for all layers |
| 289 | + for layer_idx, layer in enumerate(model.layers): |
| 290 | + prefix = f'model.layers.{layer_idx}' |
| 291 | + |
| 292 | + # Layer norms (no transformation needed) |
| 293 | + base_state[f'{prefix}.input_layernorm.weight'] = np.array( |
| 294 | + layer.pre_attention_norm.scale.value |
| 295 | + ) |
| 296 | + base_state[f'{prefix}.post_attention_layernorm.weight'] = np.array( |
| 297 | + layer.post_attention_norm.scale.value |
| 298 | + ) |
| 299 | + base_state[f'{prefix}.pre_feedforward_layernorm.weight'] = np.array( |
| 300 | + layer.pre_ffw_norm.scale.value |
| 301 | + ) |
| 302 | + base_state[f'{prefix}.post_feedforward_layernorm.weight'] = np.array( |
| 303 | + layer.post_ffw_norm.scale.value |
| 304 | + ) |
| 305 | + |
| 306 | + # Query/Key norms (no transformation needed) |
| 307 | + base_state[f'{prefix}.self_attn.q_norm.weight'] = np.array( |
| 308 | + layer.attn._query_norm.scale.value |
| 309 | + ) |
| 310 | + base_state[f'{prefix}.self_attn.k_norm.weight'] = np.array( |
| 311 | + layer.attn._key_norm.scale.value |
| 312 | + ) |
| 313 | + |
| 314 | + # Attention projections |
| 315 | + # q_einsum: nnx (num_heads, embed_dim, head_dim) → safetensors (num_heads*head_dim, embed_dim) |
| 316 | + if hasattr(layer.attn, 'q_einsum'): |
| 317 | + w = np.array( |
| 318 | + layer.attn.q_einsum.w.value |
| 319 | + ) # (num_heads, embed_dim, head_dim) |
| 320 | + w = w.transpose(0, 2, 1) # (num_heads, head_dim, embed_dim) |
| 321 | + w = w.reshape( |
| 322 | + -1, self.config.embed_dim |
| 323 | + ) # (num_heads*head_dim, embed_dim) |
| 324 | + base_state[f'{prefix}.self_attn.q_proj.weight'] = w |
| 325 | + |
| 326 | + # kv_einsum: nnx (2, num_kv_heads, embed_dim, head_dim) → |
| 327 | + # safetensors k_proj (num_kv_heads*head_dim, embed_dim) and v_proj (num_kv_heads*head_dim, embed_dim) |
| 328 | + if hasattr(layer.attn, 'kv_einsum'): |
| 329 | + kv_w = np.array( |
| 330 | + layer.attn.kv_einsum.w.value |
| 331 | + ) # (2, num_kv_heads, embed_dim, head_dim) |
| 332 | + |
| 333 | + # Split into k and v |
| 334 | + k_w = kv_w[0] # (num_kv_heads, embed_dim, head_dim) |
| 335 | + v_w = kv_w[1] # (num_kv_heads, embed_dim, head_dim) |
| 336 | + |
| 337 | + # Transform k |
| 338 | + k_w = k_w.transpose(0, 2, 1) # (num_kv_heads, head_dim, embed_dim) |
| 339 | + k_w = k_w.reshape( |
| 340 | + -1, self.config.embed_dim |
| 341 | + ) # (num_kv_heads*head_dim, embed_dim) |
| 342 | + base_state[f'{prefix}.self_attn.k_proj.weight'] = k_w |
| 343 | + |
| 344 | + # Transform v |
| 345 | + v_w = v_w.transpose(0, 2, 1) # (num_kv_heads, head_dim, embed_dim) |
| 346 | + v_w = v_w.reshape( |
| 347 | + -1, self.config.embed_dim |
| 348 | + ) # (num_kv_heads*head_dim, embed_dim) |
| 349 | + base_state[f'{prefix}.self_attn.v_proj.weight'] = v_w |
| 350 | + |
| 351 | + # o_proj (attn_vec_einsum): nnx (num_heads, head_dim, embed_dim) → safetensors (embed_dim, num_heads*head_dim) |
| 352 | + if hasattr(layer.attn, 'attn_vec_einsum'): |
| 353 | + w = np.array( |
| 354 | + layer.attn.attn_vec_einsum.w.value |
| 355 | + ) # (num_heads, head_dim, embed_dim) |
| 356 | + w = w.reshape( |
| 357 | + -1, self.config.embed_dim |
| 358 | + ) # (num_heads*head_dim, embed_dim) |
| 359 | + base_state[f'{prefix}.self_attn.o_proj.weight'] = ( |
| 360 | + w.T |
| 361 | + ) # (embed_dim, num_heads*head_dim) |
| 362 | + |
| 363 | + # MLP projections |
| 364 | + # nnx: (in_features, out_features) → safetensors: (out_features, in_features) |
| 365 | + if hasattr(layer.mlp, 'gate_proj'): |
| 366 | + base_state[f'{prefix}.mlp.gate_proj.weight'] = np.array( |
| 367 | + layer.mlp.gate_proj.kernel.value |
| 368 | + ).T |
| 369 | + |
| 370 | + if hasattr(layer.mlp, 'up_proj'): |
| 371 | + base_state[f'{prefix}.mlp.up_proj.weight'] = np.array( |
| 372 | + layer.mlp.up_proj.kernel.value |
| 373 | + ).T |
| 374 | + |
| 375 | + if hasattr(layer.mlp, 'down_proj'): |
| 376 | + base_state[f'{prefix}.mlp.down_proj.weight'] = np.array( |
| 377 | + layer.mlp.down_proj.kernel.value |
| 378 | + ).T |
| 379 | + |
| 380 | + # Save to disk |
| 381 | + safe_np.save_file( |
| 382 | + base_state, os.path.join(self.base_checkpoint_dir, 'model.safetensors') |
| 383 | + ) |
| 384 | + |
| 385 | + # Minimal config for file copying test |
| 386 | + with open(os.path.join(self.base_checkpoint_dir, 'config.json'), 'w') as f: |
| 387 | + f.write('{"model_type": "gemma3"}') |
| 388 | + |
| 389 | + return self.base_checkpoint_dir |
| 390 | + |
| 391 | + def test_kv_einsum_decomposition(self): |
| 392 | + """Test that kv_einsum is properly decomposed into k_proj and v_proj.""" |
| 393 | + # Create base model and checkpoint |
| 394 | + base_model = self._create_base_model() |
| 395 | + self.create_checkpoint(base_model) |
| 396 | + |
| 397 | + # Apply LoRA |
| 398 | + lora_model = self._apply_lora_to_model(base_model) |
| 399 | + |
| 400 | + # Save merged model |
| 401 | + self.save_merged_model(lora_model) |
| 402 | + |
| 403 | + # Load the merged state |
| 404 | + merged_state = safe_np.load_file( |
| 405 | + os.path.join(self.merged_output_dir, 'model.safetensors') |
| 406 | + ) |
| 407 | + |
| 408 | + # Verify k_proj and v_proj exist (decomposed from kv_einsum) |
| 409 | + for layer_idx in range(self.config.num_layers): |
| 410 | + k_proj_key = f'model.layers.{layer_idx}.self_attn.k_proj.weight' |
| 411 | + v_proj_key = f'model.layers.{layer_idx}.self_attn.v_proj.weight' |
| 412 | + |
| 413 | + self.assertIn( |
| 414 | + k_proj_key, merged_state, f'Missing k_proj for layer {layer_idx}' |
| 415 | + ) |
| 416 | + self.assertIn( |
| 417 | + v_proj_key, merged_state, f'Missing v_proj for layer {layer_idx}' |
| 418 | + ) |
| 419 | + |
| 420 | + # Verify shapes |
| 421 | + expected_shape = ( |
| 422 | + self.config.num_kv_heads * self.config.head_dim, |
| 423 | + self.config.embed_dim, |
| 424 | + ) |
| 425 | + self.assertEqual( |
| 426 | + merged_state[k_proj_key].shape, |
| 427 | + expected_shape, |
| 428 | + f'Wrong shape for k_proj in layer {layer_idx}', |
| 429 | + ) |
| 430 | + self.assertEqual( |
| 431 | + merged_state[v_proj_key].shape, |
| 432 | + expected_shape, |
| 433 | + f'Wrong shape for v_proj in layer {layer_idx}', |
| 434 | + ) |
| 435 | + |
| 436 | + |
177 | 437 | if __name__ == "__main__": |
178 | 438 | # Check if running in Jupyter/IPython environment |
179 | | - try: |
180 | | - get_ipython() |
| 439 | + if test_common.is_running_in_colab(): |
181 | 440 | # Running in Jupyter/IPython - run tests directly to avoid SystemExit |
182 | | - suite = unittest.TestLoader().loadTestsFromTestCase(Llama3ParamsTest) |
| 441 | + suite = unittest.TestLoader().loadTestsFromTestCase(Gemma3LoraParamsTest) |
183 | 442 | runner = unittest.TextTestRunner(verbosity=2) |
184 | 443 | runner.run(suite) |
185 | | - except NameError: |
| 444 | + else: |
186 | 445 | # Running as a script - use absltest.main() |
187 | 446 | absltest.main() |