Conversation

@sarathc-cerebras sarathc-cerebras commented Dec 7, 2025

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1 (Member)

Hi @sarathc-cerebras, thank you for the PR! The main thing missing is a conversion to modular format. You can look at the modular files for other models to see how it works, but it reduces the size of the PR a lot by importing duplicated code from other models.
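
A minimal sketch of what the modular file can look like, assuming Jais2 can reuse Llama's components (the parent classes and import paths here are illustrative, not taken from the actual PR):

# modular_jais2.py -- sketch only; the modular converter script expands this
# into the full modeling_jais2.py. Only classes that actually differ need a body.
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)


class Jais2DecoderLayer(LlamaDecoderLayer):
    pass


class Jais2Model(LlamaModel):
    pass


class Jais2ForCausalLM(LlamaForCausalLM):
    pass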

@sarathc-cerebras (Author)

@Rocketknight1 thanks for bringing this up; I have updated it to use the modular format.

@sarathc-cerebras force-pushed the add-jais2-model branch 4 times, most recently from 2ae7204 to 672e38a on December 9, 2025 at 14:13
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1 (Member) left a comment

Yes, this looks good! I made a few comments but they're small.

Comment on lines 168 to 179
class Jais2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))

Member

I think you can import this class too! We have a few other models that don't use gated linear units in the MLP. Maybe nemotron?

Author

Thanks for suggesting; it's imported from Nemotron now 👍
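
For reference, a sketch of what that reuse can look like in the modular file, assuming NemotronMLP has the same non-gated up_proj → act_fn → down_proj structure:

# Sketch: reuse Nemotron's non-gated MLP instead of redefining it.
from transformers.models.nemotron.modeling_nemotron import NemotronMLP


class Jais2MLP(NemotronMLP):
    pass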


## Overview

Jais2 is a large language model developed by MBZUAI, Inception and Cerebras Systems. It is based on the transformer architecture with several modifications including:

Member

We should probably mention that it's Arabic-focused here, right? That's one of the main selling points for jais / jais2!

Author

I have updated it as requested.

model = Jais2ForCausalLM.from_pretrained(
    self.checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)

Member

Some tests are float16 and some are bfloat16 - is this intended? If it's copied from another model then it's probably fine 😅

Author

Changed all to float16.

@vasqu (Contributor) left a comment

Left some comments; I think we can still simplify a bit and update a few things to be in line with our current standards. Overall, it's looking really good already, though!

Contributor

We need something in tokenization_auto as well.
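
For reference, a sketch of the kind of entry that goes into the TOKENIZER_MAPPING_NAMES table in models/auto/tokenization_auto.py, assuming Jais2 reuses the Llama tokenizer classes (the actual classes depend on what the Jais2 checkpoints ship with):

# Sketch of a TOKENIZER_MAPPING_NAMES entry; tokenizer classes are assumptions.
(
    "jais2",
    (
        "LlamaTokenizer" if is_sentencepiece_available() else None,
        "LlamaTokenizerFast" if is_tokenizers_available() else None,
    ),
),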

Contributor

This is the old init structure we had; can you take a look at Llama, for example?

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure

if TYPE_CHECKING:
    from .configuration_llama import *
    from .modeling_llama import *
    from .tokenization_llama import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

Much simpler

Comment on lines +218 to +227
class Jais2ForSequenceClassification(LlamaForSequenceClassification):
    pass


class Jais2ForQuestionAnswering(LlamaForQuestionAnswering):
    pass


class Jais2ForTokenClassification(LlamaForTokenClassification):
    pass

Contributor

I'd like to avoid additional classes unless we have a reason to include them

)


JAIS2_8B_CHECKPOINT = "inceptionai/Jais-2-8B-Chat"

Contributor

I would avoid having a constant here; let's just inline the string directly.
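
i.e., a sketch of the inlined form:

# Sketch: use the checkpoint string inline instead of a module-level constant.
model = Jais2ForCausalLM.from_pretrained(
    "inceptionai/Jais-2-8B-Chat",
    device_map="auto",
    torch_dtype=torch.float16,
)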

Comment on lines +83 to +92
def setUp(self):
    self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
    if self.tokenizer.chat_template is None:
        self.tokenizer.chat_template = (
            "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
        )

def tearDown(self):
    backend_empty_cache(torch_device)
    gc.collect()

Contributor

Should be enough instead:

def setUp(self):
    cleanup(torch_device, gc_collect=True)

def tearDown(self):
    # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
    # some memory allocated in the cache, which means some object is not being released properly. This causes some
    # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
    # Investigate the root cause.
    cleanup(torch_device, gc_collect=True)

We can load the tokenizer there as well.
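
Combined, a sketch of what setUp could look like with the tokenizer loading folded in (keeping the chat-template fallback from the original):

def setUp(self):
    # Clear accelerator caches before each test, then load the tokenizer.
    cleanup(torch_device, gc_collect=True)
    self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
    if self.tokenizer.chat_template is None:
        self.tokenizer.chat_template = (
            "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
        )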


@slow
@require_torch_accelerator
def test_model_logits(self):

Contributor

Let's reduce the number of tests to 2-3, e.g. one fp16 logits test and a generation test. No need to go overboard here.
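
A sketch of what a trimmed-down fp16 generation test could look like (the expected string is a placeholder, to be filled in from a run against the reference checkpoint):

@slow
@require_torch_accelerator
def test_model_generation(self):
    # Sketch only: EXPECTED_TEXT is a placeholder, not a real reference output.
    model = Jais2ForCausalLM.from_pretrained(
        self.checkpoint, device_map="auto", torch_dtype=torch.float16
    )
    inputs = self.tokenizer("مرحبا", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    text = self.tokenizer.decode(output[0], skip_special_tokens=True)
    self.assertEqual(text, EXPECTED_TEXT)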

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, jais2

@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42684&sha=5090c1
