adds jais2 model support #42684
Conversation
Rocketknight1: Hi @sarathc-cerebras, thank you for the PR! The main thing missing is a conversion to modular format. You can look at the modular files for other models to see how it works; it reduces the size of the PR a lot by importing duplicated code from other models.
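A minimal sketch of what such a modular file can look like, assuming Jais2 is Llama-like (the parent classes here are an assumption for illustration, not taken from the PR):

```python
# modular_jais2.py (sketch): subclass existing implementations instead of
# copying them; utils/modular_model_converter.py expands this into the full
# modeling_jais2.py and configuration_jais2.py files.
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import LlamaForCausalLM, LlamaModel


class Jais2Config(LlamaConfig):
    pass


class Jais2Model(LlamaModel):
    pass


class Jais2ForCausalLM(LlamaForCausalLM):
    pass
```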
sarathc-cerebras: @Rocketknight1 thanks for bringing this up, I have updated it to use the modular format.
The docs for this PR live on the preview endpoint; all of your documentation changes will be reflected there. The docs are available until 30 days after the last update.
Rocketknight1 left a comment:
Yes, this looks good! I made a few comments but they're small.
```python
class Jais2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))
```
I think you can import this class too! We have a few other models that don't use gated linear units in the MLP. Maybe nemotron?
sarathc-cerebras: Thanks for suggesting, imported from Nemotron now 👍
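For reference, the modular version of this can be a one-line subclass (a sketch; it assumes NemotronMLP's non-gated up_proj -> act_fn -> down_proj path matches the class above, which is what the discussion suggests):

```python
# modular_jais2.py (sketch): reuse Nemotron's non-gated MLP instead of
# redefining it; the modular converter inlines the parent implementation.
from ..nemotron.modeling_nemotron import NemotronMLP


class Jais2MLP(NemotronMLP):
    pass
```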
docs/source/en/model_doc/jais2.md (Outdated)
```md
## Overview

Jais2 is a large language model developed by MBZUAI, Inception and Cerebras Systems. It is based on the transformer architecture with several modifications including:
```
We should probably mention that it's Arabic-focused here, right? That's one of the main selling points for jais / jais2!
sarathc-cerebras: I have updated it as requested.
```python
model = Jais2ForCausalLM.from_pretrained(
    self.checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)
```
Some tests are float16 and some are bfloat16 - is this intended? If it's copied from another model then it's probably fine 😅
sarathc-cerebras: Changed all to float16.
vasqu left a comment:
Left some comments; I think we can still simplify a bit and update a few things to be up to date with our current standards. Overall, it's looking really good already, though.
We need something in tokenization_auto as well.
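For illustration, the registration could look roughly like the sketch below. The tokenizer class names are assumptions (the thread doesn't say which tokenizer Jais2 uses), and the exact mapping structure may differ by transformers version:

```python
# src/transformers/models/auto/tokenization_auto.py (sketch): register the
# "jais2" model type so AutoTokenizer can resolve it. The Llama tokenizer
# classes below are placeholders, not taken from the PR.
TOKENIZER_MAPPING_NAMES = OrderedDict(
    [
        # ...
        (
            "jais2",
            (
                "LlamaTokenizer" if is_sentencepiece_available() else None,
                "LlamaTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        # ...
    ]
)
```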
This is the old init structure we had; can you take a look at Llama, for example?
transformers/src/transformers/models/llama/__init__.py, lines 1 to 28 at b9951b4:
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure

if TYPE_CHECKING:
    from .configuration_llama import *
    from .modeling_llama import *
    from .tokenization_llama import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
Much simpler
```python
class Jais2ForSequenceClassification(LlamaForSequenceClassification):
    pass


class Jais2ForQuestionAnswering(LlamaForQuestionAnswering):
    pass


class Jais2ForTokenClassification(LlamaForTokenClassification):
    pass
```
I'd like to avoid additional classes unless we have a reason to include them.
```python
JAIS2_8B_CHECKPOINT = "inceptionai/Jais-2-8B-Chat"
```
Would avoid having a constant here; let's just inline the string directly.
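In other words, something like this (based on the from_pretrained call quoted earlier in the thread):

```python
# Inline the checkpoint string at the call site instead of keeping a
# module-level constant
model = Jais2ForCausalLM.from_pretrained(
    "inceptionai/Jais-2-8B-Chat",
    device_map="auto",
    torch_dtype=torch.float16,
)
```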
```python
def setUp(self):
    self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
    if self.tokenizer.chat_template is None:
        self.tokenizer.chat_template = (
            "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
        )

def tearDown(self):
    backend_empty_cache(torch_device)
    gc.collect()
```
Should be enough instead:
transformers/tests/models/llama/test_modeling_llama.py, lines 67 to 75 at b9951b4:
```python
def setup(self):
    cleanup(torch_device, gc_collect=True)

def tearDown(self):
    # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
    # some memory allocated in the cache, which means some object is not being released properly. This causes some
    # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
    # Investigate the root cause.
    cleanup(torch_device, gc_collect=True)
```
We can load the tokenizer there as well.
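Putting both suggestions together, a sketch (the checkpoint string is taken from the PR's own test constant):

```python
def setUp(self):
    # Release accelerator memory between tests, then load the tokenizer once
    cleanup(torch_device, gc_collect=True)
    self.tokenizer = AutoTokenizer.from_pretrained("inceptionai/Jais-2-8B-Chat")

def tearDown(self):
    cleanup(torch_device, gc_collect=True)
```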
```python
@slow
@require_torch_accelerator
def test_model_logits(self):
    ...
```
Let's reduce the number of tests to 2-3, e.g. one fp16 logits test and a generation test. No need to go overboard here.
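A sketch of what the reduced suite could look like, as methods inside the existing integration test class (the prompt, token count, and assertions are placeholders, not values from the PR; a real logits test would compare against recorded expected slices):

```python
@slow
@require_torch_accelerator
def test_model_logits(self):
    model = Jais2ForCausalLM.from_pretrained(
        "inceptionai/Jais-2-8B-Chat", device_map="auto", torch_dtype=torch.float16
    )
    inputs = self.tokenizer("Hello, my name is", return_tensors="pt").to(torch_device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Shape check only; replace with expected-value comparisons in practice
    self.assertEqual(logits.shape[-1], model.config.vocab_size)

@slow
@require_torch_accelerator
def test_model_generation(self):
    model = Jais2ForCausalLM.from_pretrained(
        "inceptionai/Jais-2-8B-Chat", device_map="auto", torch_dtype=torch.float16
    )
    inputs = self.tokenizer("Hello, my name is", return_tensors="pt").to(torch_device)
    output = model.generate(**inputs, max_new_tokens=10, do_sample=False)
    # Greedy decoding appends exactly max_new_tokens tokens to the prompt
    self.assertEqual(output.shape[-1], inputs["input_ids"].shape[-1] + 10)
```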
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, jais2
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42684&sha=5090c1
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.