
Commit 9ab4c8e

ncylich authored and The tunix Authors committed
Copybara import of the project:

-- 16c7f6f by Noah Cylich <[email protected]>:

Added LoRA saving and consolidated model parameter tests

- Add proper LoRA adapter saving methods for Gemma3 and Qwen3 models in their params.py
- Create a SafetensorsSaver utility for standardized model weight serialization
- Consolidate PEFT parameter tests with an abstract base class (_lora_params_test_base.py)
- Update the example notebooks to use the new saving methods with corrected logic (the previous logic was slightly incorrect)
- Fix import order and minor test adjustments

COPYBARA_INTEGRATE_REVIEW=#744 from ncylich:fixed-gemma-lora-saving 16c7f6f
PiperOrigin-RevId: 836961494
1 parent 9388533 commit 9ab4c8e
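For context on what the new saving path computes, here is a minimal, illustrative sketch of the standard LoRA merge arithmetic followed by a safetensors dump. The tensor names, shapes, and variables below are hypothetical and not tunix API; in tunix the actual entry point is gemma3_params.save_lora_merged_model_as_safetensors, shown in the diff further down.

# Illustrative sketch only (not the tunix implementation): merge a LoRA update
# into a frozen base weight and write the result with safetensors.
import numpy as np
import safetensors.numpy as safe_np

rank, alpha = 8, 16.0     # hypothetical LoRA hyperparameters
out_dim, in_dim = 64, 64  # hypothetical projection shape

w_base = np.random.randn(out_dim, in_dim).astype(np.float32)  # frozen base weight
lora_a = np.random.randn(rank, in_dim).astype(np.float32)     # LoRA A (down-projection)
lora_b = np.random.randn(out_dim, rank).astype(np.float32)    # LoRA B (up-projection)

# Standard LoRA merge: W' = W + (alpha / rank) * (B @ A)
w_merged = w_base + (alpha / rank) * (lora_b @ lora_a)

# Write the merged tensor under a HuggingFace-style key, matching the key
# layout the tests below check (e.g. 'model.layers.0.mlp.up_proj.weight').
safe_np.save_file(
    {'model.layers.0.mlp.up_proj.weight': w_merged}, 'model.safetensors'
)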

File tree

9 files changed: 19860 additions & 18669 deletions


examples/dpo_gemma.ipynb

Lines changed: 6286 additions & 6321 deletions
Large diffs are not rendered by default.

examples/grpo_gemma.ipynb

Lines changed: 3012 additions & 2936 deletions
Large diffs are not rendered by default.

examples/qlora_gemma.ipynb

Lines changed: 9366 additions & 9402 deletions
Large diffs are not rendered by default.

tests/models/gemma_all/gemma_params_test.py

Lines changed: 269 additions & 10 deletions
@@ -30,15 +30,20 @@
 # Tests are skipped if model paths are not configured.
 # END-GOOGLE-INTERNAL

+import os
 import unittest

+from absl.testing import absltest
 from absl.testing import parameterized
-from flax import nnx
-import jax
-from tunix.models.gemma3 import params as gemma3_params_lib
-import numpy as np
 from flax.traverse_util import flatten_dict
-
+import jax.numpy as jnp
+import numpy as np
+import safetensors.numpy as safe_np
+from tunix.tests import lora_params_test_base
+from tunix.models.gemma3 import model as gemma3_model
+from tunix.models.gemma3 import params as gemma3_params
+from tunix.models.gemma3 import params_safetensors as gemma3_params_safetensors
+from tunix.tests import test_common

 class GemmaParamsTest(parameterized.TestCase):

@@ -90,7 +95,7 @@ def test_map_from_upstream_checkpoint(self, model_type):
         "transformer/layer_0/pre_ffw_norm": {"scale": pre_ffw},
     }

-    mapped = gemma3_params_lib.map_from_upstream_checkpoint(upstream, model_type)
+    mapped = gemma3_params.map_from_upstream_checkpoint(upstream, model_type)
     flat_m = flatten_dict(mapped)  # tuple keys

     # --- Keys & shapes we expect after mapping (tiny) ---
@@ -174,14 +179,268 @@ def test_map_from_upstream_checkpoint(self, model_type):
     assert not any(k[0] == 'siglip_encoder' for k in flat_m.keys())
     assert ('embedder', 'mm_patch') not in mapped.get('embedder', {})

+
+class Gemma3LoraParamsTest(lora_params_test_base.LoraParamsTestBase):
+  """Tests for Gemma3 LoRA merged model saving and loading."""
+
+  def create_config(self):
+    """Create Gemma3 model config for testing."""
+    return gemma3_model.ModelConfig(
+        num_layers=2,
+        num_embed=256,
+        embed_dim=64,
+        hidden_dim=128,
+        num_heads=4,
+        head_dim=16,
+        num_kv_heads=1,
+        sliding_window_size=128,  # Required for LOCAL_SLIDING attention
+    )
+
+  def get_model_class(self):
+    """Get Gemma3 model class."""
+    return gemma3_model.Gemma3
+
+  def get_lora_module_path(self) -> str:
+    """Get LoRA target modules for Gemma3."""
+    return '.*q_einsum|.*kv_einsum|.*attn_vec_einsum|.*gate_proj|.*up_proj|.*down_proj'
+
+  def get_projection_keys(self, layer_idx: int) -> list[str]:
+    """Get projection keys for Gemma3."""
+    prefix = f'model.layers.{layer_idx}'
+    return [
+        f'{prefix}.self_attn.q_proj.weight',
+        f'{prefix}.self_attn.k_proj.weight',
+        f'{prefix}.self_attn.v_proj.weight',
+        f'{prefix}.self_attn.o_proj.weight',
+        f'{prefix}.mlp.gate_proj.weight',
+        f'{prefix}.mlp.up_proj.weight',
+        f'{prefix}.mlp.down_proj.weight',
+    ]
+
+  def save_merged_model(self, lora_model):
+    """Save Gemma3 LoRA merged model."""
+    gemma3_params.save_lora_merged_model_as_safetensors(
+        local_model_path=self.base_checkpoint_dir,
+        output_dir=self.merged_output_dir,
+        lora_model=lora_model,
+        rank=self.rank,
+        alpha=self.alpha,
+    )
+
+  def create_model_from_checkpoint(self, checkpoint_dir: str):
+    """Load Gemma3 model from checkpoint."""
+    return gemma3_params_safetensors.create_model_from_safe_tensors(
+        file_dir=checkpoint_dir,
+        config=self.config,
+        mesh=None,
+        dtype=jnp.float32,
+    )
+
+  def _create_test_inputs(self):
+    """Create test inputs for Gemma3 forward pass."""
+    batch_size = 2
+    seq_len = 10
+
+    input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
+    positions = jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0)
+    # Gemma3 requires attention mask - create causal mask
+    attention_mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len)))
+
+    return input_tokens, positions, attention_mask
+
+  def _run_forward_pass(self, model, input_tokens, positions, attention_mask):
+    """Run forward pass through Gemma3 model."""
+    # Gemma3 uses `last_tokens` parameter name
+    return model(
+        last_tokens=input_tokens,
+        positions=positions,
+        cache=None,
+        attention_mask=attention_mask,
+    )
+
+  def create_checkpoint(self, model) -> str:
+    """Extract model weights and save in safetensors format.
+
+    Uses the model's actual weights and applies inverse transformations
+    to create a valid safetensors file compatible with Gemma3 loader.
+
+    Key difference from Qwen3: kv_einsum must be decomposed into k_proj and
+    v_proj.
+
+    Args:
+      model: Base model to extract weights from.
+
+    Returns:
+      Path to the created checkpoint directory.
+    """
+    os.makedirs(self.base_checkpoint_dir, exist_ok=True)
+
+    base_state = {}
+
+    # Embedder (no transformation needed)
+    base_state['model.embed_tokens.weight'] = np.array(
+        model.embedder.input_embedding.value
+    )
+
+    # Final norm (no transformation needed)
+    base_state['model.norm.weight'] = np.array(model.final_norm.scale.value)
+
+    # Extract and transform weights for all layers
+    for layer_idx, layer in enumerate(model.layers):
+      prefix = f'model.layers.{layer_idx}'
+
+      # Layer norms (no transformation needed)
+      base_state[f'{prefix}.input_layernorm.weight'] = np.array(
+          layer.pre_attention_norm.scale.value
+      )
+      base_state[f'{prefix}.post_attention_layernorm.weight'] = np.array(
+          layer.post_attention_norm.scale.value
+      )
+      base_state[f'{prefix}.pre_feedforward_layernorm.weight'] = np.array(
+          layer.pre_ffw_norm.scale.value
+      )
+      base_state[f'{prefix}.post_feedforward_layernorm.weight'] = np.array(
+          layer.post_ffw_norm.scale.value
+      )
+
+      # Query/Key norms (no transformation needed)
+      base_state[f'{prefix}.self_attn.q_norm.weight'] = np.array(
+          layer.attn._query_norm.scale.value
+      )
+      base_state[f'{prefix}.self_attn.k_norm.weight'] = np.array(
+          layer.attn._key_norm.scale.value
+      )
+
+      # Attention projections
+      # q_einsum: nnx (num_heads, embed_dim, head_dim) → safetensors (num_heads*head_dim, embed_dim)
+      if hasattr(layer.attn, 'q_einsum'):
+        w = np.array(
+            layer.attn.q_einsum.w.value
+        )  # (num_heads, embed_dim, head_dim)
+        w = w.transpose(0, 2, 1)  # (num_heads, head_dim, embed_dim)
+        w = w.reshape(
+            -1, self.config.embed_dim
+        )  # (num_heads*head_dim, embed_dim)
+        base_state[f'{prefix}.self_attn.q_proj.weight'] = w
+
+      # kv_einsum: nnx (2, num_kv_heads, embed_dim, head_dim) →
+      # safetensors k_proj (num_kv_heads*head_dim, embed_dim) and v_proj (num_kv_heads*head_dim, embed_dim)
+      if hasattr(layer.attn, 'kv_einsum'):
+        kv_w = np.array(
+            layer.attn.kv_einsum.w.value
+        )  # (2, num_kv_heads, embed_dim, head_dim)
+
+        # Split into k and v
+        k_w = kv_w[0]  # (num_kv_heads, embed_dim, head_dim)
+        v_w = kv_w[1]  # (num_kv_heads, embed_dim, head_dim)
+
+        # Transform k
+        k_w = k_w.transpose(0, 2, 1)  # (num_kv_heads, head_dim, embed_dim)
+        k_w = k_w.reshape(
+            -1, self.config.embed_dim
+        )  # (num_kv_heads*head_dim, embed_dim)
+        base_state[f'{prefix}.self_attn.k_proj.weight'] = k_w
+
+        # Transform v
+        v_w = v_w.transpose(0, 2, 1)  # (num_kv_heads, head_dim, embed_dim)
+        v_w = v_w.reshape(
+            -1, self.config.embed_dim
+        )  # (num_kv_heads*head_dim, embed_dim)
+        base_state[f'{prefix}.self_attn.v_proj.weight'] = v_w
+
+      # o_proj (attn_vec_einsum): nnx (num_heads, head_dim, embed_dim) → safetensors (embed_dim, num_heads*head_dim)
+      if hasattr(layer.attn, 'attn_vec_einsum'):
+        w = np.array(
+            layer.attn.attn_vec_einsum.w.value
+        )  # (num_heads, head_dim, embed_dim)
+        w = w.reshape(
+            -1, self.config.embed_dim
+        )  # (num_heads*head_dim, embed_dim)
+        base_state[f'{prefix}.self_attn.o_proj.weight'] = (
+            w.T
+        )  # (embed_dim, num_heads*head_dim)
+
+      # MLP projections
+      # nnx: (in_features, out_features) → safetensors: (out_features, in_features)
+      if hasattr(layer.mlp, 'gate_proj'):
+        base_state[f'{prefix}.mlp.gate_proj.weight'] = np.array(
+            layer.mlp.gate_proj.kernel.value
+        ).T
+
+      if hasattr(layer.mlp, 'up_proj'):
+        base_state[f'{prefix}.mlp.up_proj.weight'] = np.array(
+            layer.mlp.up_proj.kernel.value
+        ).T
+
+      if hasattr(layer.mlp, 'down_proj'):
+        base_state[f'{prefix}.mlp.down_proj.weight'] = np.array(
+            layer.mlp.down_proj.kernel.value
+        ).T
+
+    # Save to disk
+    safe_np.save_file(
+        base_state, os.path.join(self.base_checkpoint_dir, 'model.safetensors')
+    )
+
+    # Minimal config for file copying test
+    with open(os.path.join(self.base_checkpoint_dir, 'config.json'), 'w') as f:
+      f.write('{"model_type": "gemma3"}')
+
+    return self.base_checkpoint_dir
+
+  def test_kv_einsum_decomposition(self):
+    """Test that kv_einsum is properly decomposed into k_proj and v_proj."""
+    # Create base model and checkpoint
+    base_model = self._create_base_model()
+    self.create_checkpoint(base_model)
+
+    # Apply LoRA
+    lora_model = self._apply_lora_to_model(base_model)
+
+    # Save merged model
+    self.save_merged_model(lora_model)
+
+    # Load the merged state
+    merged_state = safe_np.load_file(
+        os.path.join(self.merged_output_dir, 'model.safetensors')
+    )
+
+    # Verify k_proj and v_proj exist (decomposed from kv_einsum)
+    for layer_idx in range(self.config.num_layers):
+      k_proj_key = f'model.layers.{layer_idx}.self_attn.k_proj.weight'
+      v_proj_key = f'model.layers.{layer_idx}.self_attn.v_proj.weight'
+
+      self.assertIn(
+          k_proj_key, merged_state, f'Missing k_proj for layer {layer_idx}'
+      )
+      self.assertIn(
+          v_proj_key, merged_state, f'Missing v_proj for layer {layer_idx}'
+      )
+
+      # Verify shapes
+      expected_shape = (
+          self.config.num_kv_heads * self.config.head_dim,
+          self.config.embed_dim,
+      )
+      self.assertEqual(
+          merged_state[k_proj_key].shape,
+          expected_shape,
+          f'Wrong shape for k_proj in layer {layer_idx}',
+      )
+      self.assertEqual(
+          merged_state[v_proj_key].shape,
+          expected_shape,
+          f'Wrong shape for v_proj in layer {layer_idx}',
+      )
+
+
 if __name__ == "__main__":
   # Check if running in Jupyter/IPython environment
-  try:
-    get_ipython()
+  if test_common.is_running_in_colab():
     # Running in Jupyter/IPython - run tests directly to avoid SystemExit
-    suite = unittest.TestLoader().loadTestsFromTestCase(Llama3ParamsTest)
+    suite = unittest.TestLoader().loadTestsFromTestCase(Gemma3LoraParamsTest)
     runner = unittest.TextTestRunner(verbosity=2)
     runner.run(suite)
-  except NameError:
+  else:
    # Running as a script - use absltest.main()
    absltest.main()
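As a usage note, a merged checkpoint written by save_lora_merged_model_as_safetensors can be loaded back with the same loader the test exercises. The following is a minimal sketch based only on the calls shown in the diff above (create_config and create_model_from_checkpoint); the checkpoint path is a placeholder.

import jax.numpy as jnp
from tunix.models.gemma3 import model as gemma3_model
from tunix.models.gemma3 import params_safetensors as gemma3_params_safetensors

# Tiny test-sized config mirroring create_config() in the diff above.
config = gemma3_model.ModelConfig(
    num_layers=2,
    num_embed=256,
    embed_dim=64,
    hidden_dim=128,
    num_heads=4,
    head_dim=16,
    num_kv_heads=1,
    sliding_window_size=128,
)

# Load the merged safetensors checkpoint the same way the test's
# create_model_from_checkpoint() does; the directory is a placeholder.
model = gemma3_params_safetensors.create_model_from_safe_tensors(
    file_dir='/tmp/merged_output_dir',
    config=config,
    mesh=None,
    dtype=jnp.float32,
)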
