Skip to content

Commit a363e45

Browse files
addresses review comment
1 parent 7dfa45e commit a363e45

File tree

3 files changed

+6
-16
lines changed

3 files changed

+6
-16
lines changed

docs/source/en/model_doc/jais2.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ rendered properly in your Markdown viewer.
1919

2020
## Overview
2121

22-
Jais2 is a large language model developed by MBZUAI, Inception and Cerebras Systems. It is based on the transformer architecture with several modifications including:
22+
Jais2 a next-generation Arabic open-weight LLM trained on the richest Arabic-first dataset to date. Built from the ground up with 8B and 70B parameters, Jais 2 understands Arabic the way it's truly spoken across dialects, cuulutre, and modern expression. It is developed by MBZUAI, Inception and Cerebras Systems and based on the transformer architecture with modifications including:
2323

2424
- LayerNorm instead of RMSNorm
2525
- ReLU² activation function

src/transformers/models/jais2/modular_jais2.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import torch.nn as nn
1919

20-
from ...activations import ACT2FN
2120
from ...modeling_rope_utils import RopeParameters
2221
from ...utils import auto_docstring, can_return_tuple
2322
from ..llama.configuration_llama import LlamaConfig
@@ -30,6 +29,7 @@
3029
LlamaModel,
3130
LlamaPreTrainedModel,
3231
)
32+
from ..nemotron.modeling_nemotron import NemotronMLP
3333

3434

3535
class Jais2Config(LlamaConfig):
@@ -165,18 +165,8 @@ def __init__(
165165
__all__ = ["Jais2Config"]
166166

167167

168-
class Jais2MLP(nn.Module):
169-
def __init__(self, config):
170-
super().__init__()
171-
self.config = config
172-
self.hidden_size = config.hidden_size
173-
self.intermediate_size = config.intermediate_size
174-
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
175-
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
176-
self.act_fn = ACT2FN[config.hidden_act]
177-
178-
def forward(self, x):
179-
return self.down_proj(self.act_fn(self.up_proj(x)))
168+
class Jais2MLP(NemotronMLP):
169+
pass
180170

181171

182172
class Jais2DecoderLayer(LlamaDecoderLayer):

tests/models/jais2/test_modeling_jais2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ def test_model_logits(self):
128128
@slow
129129
@require_torch_accelerator
130130
def test_model_logits_bf16(self):
131-
"""Test model logits in bfloat16 precision."""
131+
"""Test model logits in float16 precision."""
132132
model = Jais2ForCausalLM.from_pretrained(
133133
self.checkpoint,
134134
device_map="auto",
135-
torch_dtype=torch.bfloat16,
135+
torch_dtype=torch.float16,
136136
)
137137

138138
input_text = "The capital of France is"

0 commit comments

Comments
 (0)