Add-helium (#35669)
* Add the helium model. * Add a missing helium. * And add another missing helium. * Use float for the rmsnorm mul. * Add the Helium tokenizer converter. * Add the pad token as suggested by Arthur. * Update the RMSNorm + some other tweaks. * Fix more rebase issues. * fix copies and style * fixes and add helium.md * add missing tests * udpate the backlink * oups * style * update init, and expected results * small fixes * match test outputs * style fixup, fix doc builder * add dummies and we should be good to go!z * update sdpa and fa2 documentation --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@@ -452,6 +452,8 @@
|
|||||||
title: Granite
|
title: Granite
|
||||||
- local: model_doc/granitemoe
|
- local: model_doc/granitemoe
|
||||||
title: GraniteMoe
|
title: GraniteMoe
|
||||||
|
- local: model_doc/helium
|
||||||
|
title: Helium
|
||||||
- local: model_doc/herbert
|
- local: model_doc/herbert
|
||||||
title: HerBERT
|
title: HerBERT
|
||||||
- local: model_doc/ibert
|
- local: model_doc/ibert
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| [Graphormer](model_doc/graphormer) | ✅ | ❌ | ❌ |
|
| [Graphormer](model_doc/graphormer) | ✅ | ❌ | ❌ |
|
||||||
| [Grounding DINO](model_doc/grounding-dino) | ✅ | ❌ | ❌ |
|
| [Grounding DINO](model_doc/grounding-dino) | ✅ | ❌ | ❌ |
|
||||||
| [GroupViT](model_doc/groupvit) | ✅ | ✅ | ❌ |
|
| [GroupViT](model_doc/groupvit) | ✅ | ✅ | ❌ |
|
||||||
|
| [Helium](model_doc/helium) | ✅ | ❌ | ❌ |
|
||||||
| [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ |
|
| [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ |
|
||||||
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
|
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
|
||||||
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
|
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
|
||||||
|
|||||||
158
docs/source/en/model_doc/helium.md
Normal file
158
docs/source/en/model_doc/helium.md
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
<!--Copyright 2024 Kyutai and The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Helium
|
||||||
|
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Helium was proposed in [Announcing Helium-1 Preview](https://kyutai.org/2025/01/13/helium.html) by the Kyutai Team.
|
||||||
|
|
||||||
|
|
||||||
|
Helium-1 preview is a lightweight language model with 2B parameters, targeting edge and mobile devices.
|
||||||
|
It supports the following languages: English, French, German, Italian, Portuguese, Spanish.
|
||||||
|
|
||||||
|
- **Developed by:** Kyutai
|
||||||
|
- **Model type:** Large Language Model
|
||||||
|
- **Language(s) (NLP):** English, French, German, Italian, Portuguese, Spanish
|
||||||
|
- **License:** CC-BY 4.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
<!-- This section describes the evaluation protocols and provides the results. -->
|
||||||
|
|
||||||
|
#### Testing Data
|
||||||
|
|
||||||
|
<!-- This should link to a Dataset Card if possible. -->
|
||||||
|
|
||||||
|
The model was evaluated on MMLU, TriviaQA, NaturalQuestions, ARC Easy & Challenge, Open Book QA, Common Sense QA,
|
||||||
|
Physical Interaction QA, Social Interaction QA, HellaSwag, WinoGrande, Multilingual Knowledge QA, FLORES 200.
|
||||||
|
|
||||||
|
#### Metrics
|
||||||
|
|
||||||
|
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
||||||
|
|
||||||
|
We report accuracy on MMLU, ARC, OBQA, CSQA, PIQA, SIQA, HellaSwag, WinoGrande.
|
||||||
|
We report exact match on TriviaQA, NQ and MKQA.
|
||||||
|
We report BLEU on FLORES.
|
||||||
|
|
||||||
|
### English Results
|
||||||
|
|
||||||
|
| Benchmark | Helium-1 Preview | HF SmolLM2 (1.7B) | Gemma-2 (2.6B) | Llama-3.2 (3B) | Qwen2.5 (1.5B) |
|
||||||
|
|--------------|--------|--------|--------|--------|--------|
|
||||||
|
| | | | | | |
|
||||||
|
| MMLU | 51.2 | 50.4 | 53.1 | 56.6 | 61.0 |
|
||||||
|
| NQ | 17.3 | 15.1 | 17.7 | 22.0 | 13.1 |
|
||||||
|
| TQA | 47.9 | 45.4 | 49.9 | 53.6 | 35.9 |
|
||||||
|
| ARC E | 80.9 | 81.8 | 81.1 | 84.6 | 89.7 |
|
||||||
|
| ARC C | 62.7 | 64.7 | 66.0 | 69.0 | 77.2 |
|
||||||
|
| OBQA | 63.8 | 61.4 | 64.6 | 68.4 | 73.8 |
|
||||||
|
| CSQA | 65.6 | 59.0 | 64.4 | 65.4 | 72.4 |
|
||||||
|
| PIQA | 77.4 | 77.7 | 79.8 | 78.9 | 76.0 |
|
||||||
|
| SIQA | 64.4 | 57.5 | 61.9 | 63.8 | 68.7 |
|
||||||
|
| HS | 69.7 | 73.2 | 74.7 | 76.9 | 67.5 |
|
||||||
|
| WG | 66.5 | 65.6 | 71.2 | 72.0 | 64.8 |
|
||||||
|
| | | | | | |
|
||||||
|
| Average | 60.7 | 59.3 | 62.2 | 64.7 | 63.6 |
|
||||||
|
|
||||||
|
#### Multilingual Results
|
||||||
|
|
||||||
|
| Language | Benchmark | Helium-1 Preview | HF SmolLM2 (1.7B) | Gemma-2 (2.6B) | Llama-3.2 (3B) | Qwen2.5 (1.5B) |
|
||||||
|
|-----|--------------|--------|--------|--------|--------|--------|
|
||||||
|
| | | | | | | |
|
||||||
|
|German| MMLU | 45.6 | 35.3 | 45.0 | 47.5 | 49.5 |
|
||||||
|
|| ARC C | 56.7 | 38.4 | 54.7 | 58.3 | 60.2 |
|
||||||
|
|| HS | 53.5 | 33.9 | 53.4 | 53.7 | 42.8 |
|
||||||
|
|| MKQA | 16.1 | 7.1 | 18.9 | 20.2 | 10.4 |
|
||||||
|
| | | | | | | |
|
||||||
|
|Spanish| MMLU | 46.5 | 38.9 | 46.2 | 49.6 | 52.8 |
|
||||||
|
|| ARC C | 58.3 | 43.2 | 58.8 | 60.0 | 68.1 |
|
||||||
|
|| HS | 58.6 | 40.8 | 60.5 | 61.1 | 51.4 |
|
||||||
|
|| MKQA | 16.0 | 7.9 | 18.5 | 20.6 | 10.6 |
|
||||||
|
|
||||||
|
|
||||||
|
## Technical Specifications
|
||||||
|
|
||||||
|
### Model Architecture and Objective
|
||||||
|
|
||||||
|
| Hyperparameter | Value |
|
||||||
|
|--------------|--------|
|
||||||
|
| Layers | 24 |
|
||||||
|
| Heads | 20 |
|
||||||
|
| Model dimension | 2560 |
|
||||||
|
| MLP dimension | 7040 |
|
||||||
|
| Context size | 4096 |
|
||||||
|
| Theta RoPE | 100,000 |
|
||||||
|
|
||||||
|
Tips:
|
||||||
|
|
||||||
|
- This model was contributed by [Laurent Mazare](https://huggingface.co/lmz)
|
||||||
|
|
||||||
|
|
||||||
|
## Usage tips
|
||||||
|
|
||||||
|
`Helium` can be found on the [Huggingface Hub](https://huggingface.co/collections/kyutai/helium-1-preview)
|
||||||
|
|
||||||
|
In the following, we demonstrate how to use `helium-1-preview` for the inference.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
>>> device = "cuda" # the device to load the model onto
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("helium-1-preview", device_map="auto")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("helium-1-preview")
|
||||||
|
|
||||||
|
>>> prompt = "Give me a short introduction to large language model."
|
||||||
|
|
||||||
|
>>> messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
|
||||||
|
>>> model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True)
|
||||||
|
|
||||||
|
>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
|
||||||
|
|
||||||
|
>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
## HeliumConfig
|
||||||
|
|
||||||
|
[[autodoc]] HeliumConfig
|
||||||
|
|
||||||
|
## HeliumModel
|
||||||
|
|
||||||
|
[[autodoc]] HeliumModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## HeliumForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] HeliumForCausalLM
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## HeliumForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] HeliumForSequenceClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## HeliumForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] HeliumForTokenClassification
|
||||||
|
- forward
|
||||||
@@ -109,6 +109,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
|||||||
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
|
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
|
||||||
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
|
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
|
||||||
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
|
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
|
||||||
|
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/heliumtransformers.HeliumModel)
|
||||||
|
|
||||||
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
|
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
|
||||||
|
|
||||||
@@ -324,6 +325,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
|
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
|
||||||
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
|
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
|
||||||
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
|
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
|
||||||
|
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/heliumtransformers.HeliumModel)
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
|
|||||||
@@ -498,6 +498,7 @@ _import_structure = {
|
|||||||
"GroupViTTextConfig",
|
"GroupViTTextConfig",
|
||||||
"GroupViTVisionConfig",
|
"GroupViTVisionConfig",
|
||||||
],
|
],
|
||||||
|
"models.helium": ["HeliumConfig"],
|
||||||
"models.herbert": ["HerbertTokenizer"],
|
"models.herbert": ["HerbertTokenizer"],
|
||||||
"models.hiera": ["HieraConfig"],
|
"models.hiera": ["HieraConfig"],
|
||||||
"models.hubert": ["HubertConfig"],
|
"models.hubert": ["HubertConfig"],
|
||||||
@@ -2506,6 +2507,15 @@ else:
|
|||||||
"GroupViTVisionModel",
|
"GroupViTVisionModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.helium"].extend(
|
||||||
|
[
|
||||||
|
"HeliumForCausalLM",
|
||||||
|
"HeliumForSequenceClassification",
|
||||||
|
"HeliumForTokenClassification",
|
||||||
|
"HeliumModel",
|
||||||
|
"HeliumPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.hiera"].extend(
|
_import_structure["models.hiera"].extend(
|
||||||
[
|
[
|
||||||
"HieraBackbone",
|
"HieraBackbone",
|
||||||
@@ -5529,6 +5539,7 @@ if TYPE_CHECKING:
|
|||||||
GroupViTTextConfig,
|
GroupViTTextConfig,
|
||||||
GroupViTVisionConfig,
|
GroupViTVisionConfig,
|
||||||
)
|
)
|
||||||
|
from .models.helium import HeliumConfig
|
||||||
from .models.herbert import HerbertTokenizer
|
from .models.herbert import HerbertTokenizer
|
||||||
from .models.hiera import HieraConfig
|
from .models.hiera import HieraConfig
|
||||||
from .models.hubert import HubertConfig
|
from .models.hubert import HubertConfig
|
||||||
@@ -7371,6 +7382,13 @@ if TYPE_CHECKING:
|
|||||||
GroupViTTextModel,
|
GroupViTTextModel,
|
||||||
GroupViTVisionModel,
|
GroupViTVisionModel,
|
||||||
)
|
)
|
||||||
|
from .models.helium import (
|
||||||
|
HeliumForCausalLM,
|
||||||
|
HeliumForSequenceClassification,
|
||||||
|
HeliumForTokenClassification,
|
||||||
|
HeliumModel,
|
||||||
|
HeliumPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.hiera import (
|
from .models.hiera import (
|
||||||
HieraBackbone,
|
HieraBackbone,
|
||||||
HieraForImageClassification,
|
HieraForImageClassification,
|
||||||
|
|||||||
@@ -1446,6 +1446,95 @@ class MoshiConverter(SpmConverter):
|
|||||||
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
|
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumConverter(SpmConverter):
|
||||||
|
handle_byte_fallback = True
|
||||||
|
|
||||||
|
def __init__(self, vocab_file=None, *args):
|
||||||
|
requires_backends(self, "protobuf")
|
||||||
|
|
||||||
|
Converter.__init__(self, vocab_file)
|
||||||
|
|
||||||
|
model_pb2 = import_protobuf()
|
||||||
|
|
||||||
|
m = model_pb2.ModelProto()
|
||||||
|
with open(vocab_file, "rb") as f:
|
||||||
|
m.ParseFromString(f.read())
|
||||||
|
self.proto = m
|
||||||
|
|
||||||
|
def tokenizer(self, proto):
|
||||||
|
vocab_scores = self.vocab(proto)
|
||||||
|
tokenizer = Tokenizer(
|
||||||
|
Unigram(
|
||||||
|
vocab_scores,
|
||||||
|
unk_id=self.unk_id(proto),
|
||||||
|
byte_fallback=self.handle_byte_fallback,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# control tokens are special
|
||||||
|
# user defined symbols are not
|
||||||
|
# both user and control tokens are AddedTokens
|
||||||
|
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
|
||||||
|
spm_added_tokens = [
|
||||||
|
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
|
||||||
|
for id, p in enumerate(proto.pieces)
|
||||||
|
if p.type in [3, 4]
|
||||||
|
]
|
||||||
|
tokenizer.add_tokens(
|
||||||
|
[
|
||||||
|
AddedToken(token, normalized=False, special=special, single_word=True)
|
||||||
|
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
tokenizer.add_tokens([AddedToken("\n", normalized=False, special=False)])
|
||||||
|
tokenizer.enable_padding(pad_token="<pad>", pad_id=3)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def vocab(self, proto):
|
||||||
|
vocab = []
|
||||||
|
for piece in proto.pieces:
|
||||||
|
if piece.piece == "<0x0A>":
|
||||||
|
vocab += [("\n", piece.score)]
|
||||||
|
else:
|
||||||
|
vocab += [(piece.piece, piece.score)]
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
def unk_id(self, proto):
|
||||||
|
unk_id = 0
|
||||||
|
return unk_id
|
||||||
|
|
||||||
|
def decoder(self, replacement, add_prefix_space):
|
||||||
|
sequence = [
|
||||||
|
decoders.Replace("▁", " "),
|
||||||
|
decoders.ByteFallback(),
|
||||||
|
decoders.Fuse(),
|
||||||
|
]
|
||||||
|
sequence += [decoders.Strip(content=" ", left=1)]
|
||||||
|
return decoders.Sequence(sequence)
|
||||||
|
|
||||||
|
def normalizer(self, proto):
|
||||||
|
return normalizers.Sequence([normalizers.Prepend(" "), normalizers.Replace(r" ", "▁")])
|
||||||
|
|
||||||
|
def pre_tokenizer(self, replacement, add_prefix_space):
|
||||||
|
return pre_tokenizers.Sequence([pre_tokenizers.Split("\n", "contiguous")])
|
||||||
|
|
||||||
|
def post_processor(self):
|
||||||
|
return processors.TemplateProcessing(
|
||||||
|
single=[
|
||||||
|
"<s>",
|
||||||
|
"$A",
|
||||||
|
],
|
||||||
|
pair=[
|
||||||
|
"<s>",
|
||||||
|
"$A",
|
||||||
|
"<s>",
|
||||||
|
"$B",
|
||||||
|
],
|
||||||
|
special_tokens=[
|
||||||
|
("<s>", 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
||||||
def bytes_to_unicode():
|
def bytes_to_unicode():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ from . import (
|
|||||||
granitemoe,
|
granitemoe,
|
||||||
grounding_dino,
|
grounding_dino,
|
||||||
groupvit,
|
groupvit,
|
||||||
|
helium,
|
||||||
herbert,
|
herbert,
|
||||||
hiera,
|
hiera,
|
||||||
hubert,
|
hubert,
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("graphormer", "GraphormerConfig"),
|
("graphormer", "GraphormerConfig"),
|
||||||
("grounding-dino", "GroundingDinoConfig"),
|
("grounding-dino", "GroundingDinoConfig"),
|
||||||
("groupvit", "GroupViTConfig"),
|
("groupvit", "GroupViTConfig"),
|
||||||
|
("helium", "HeliumConfig"),
|
||||||
("hiera", "HieraConfig"),
|
("hiera", "HieraConfig"),
|
||||||
("hubert", "HubertConfig"),
|
("hubert", "HubertConfig"),
|
||||||
("ibert", "IBertConfig"),
|
("ibert", "IBertConfig"),
|
||||||
@@ -458,6 +459,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("graphormer", "Graphormer"),
|
("graphormer", "Graphormer"),
|
||||||
("grounding-dino", "Grounding DINO"),
|
("grounding-dino", "Grounding DINO"),
|
||||||
("groupvit", "GroupViT"),
|
("groupvit", "GroupViT"),
|
||||||
|
("helium", "Helium"),
|
||||||
("herbert", "HerBERT"),
|
("herbert", "HerBERT"),
|
||||||
("hiera", "Hiera"),
|
("hiera", "Hiera"),
|
||||||
("hubert", "Hubert"),
|
("hubert", "Hubert"),
|
||||||
|
|||||||
@@ -132,6 +132,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("graphormer", "GraphormerModel"),
|
("graphormer", "GraphormerModel"),
|
||||||
("grounding-dino", "GroundingDinoModel"),
|
("grounding-dino", "GroundingDinoModel"),
|
||||||
("groupvit", "GroupViTModel"),
|
("groupvit", "GroupViTModel"),
|
||||||
|
("helium", "HeliumModel"),
|
||||||
("hiera", "HieraModel"),
|
("hiera", "HieraModel"),
|
||||||
("hubert", "HubertModel"),
|
("hubert", "HubertModel"),
|
||||||
("ibert", "IBertModel"),
|
("ibert", "IBertModel"),
|
||||||
@@ -517,6 +518,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("gptj", "GPTJForCausalLM"),
|
("gptj", "GPTJForCausalLM"),
|
||||||
("granite", "GraniteForCausalLM"),
|
("granite", "GraniteForCausalLM"),
|
||||||
("granitemoe", "GraniteMoeForCausalLM"),
|
("granitemoe", "GraniteMoeForCausalLM"),
|
||||||
|
("helium", "HeliumForCausalLM"),
|
||||||
("jamba", "JambaForCausalLM"),
|
("jamba", "JambaForCausalLM"),
|
||||||
("jetmoe", "JetMoeForCausalLM"),
|
("jetmoe", "JetMoeForCausalLM"),
|
||||||
("llama", "LlamaForCausalLM"),
|
("llama", "LlamaForCausalLM"),
|
||||||
@@ -989,6 +991,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("gpt_neo", "GPTNeoForSequenceClassification"),
|
("gpt_neo", "GPTNeoForSequenceClassification"),
|
||||||
("gpt_neox", "GPTNeoXForSequenceClassification"),
|
("gpt_neox", "GPTNeoXForSequenceClassification"),
|
||||||
("gptj", "GPTJForSequenceClassification"),
|
("gptj", "GPTJForSequenceClassification"),
|
||||||
|
("helium", "HeliumForSequenceClassification"),
|
||||||
("ibert", "IBertForSequenceClassification"),
|
("ibert", "IBertForSequenceClassification"),
|
||||||
("jamba", "JambaForSequenceClassification"),
|
("jamba", "JambaForSequenceClassification"),
|
||||||
("jetmoe", "JetMoeForSequenceClassification"),
|
("jetmoe", "JetMoeForSequenceClassification"),
|
||||||
@@ -1182,6 +1185,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
|
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
|
||||||
("gpt_neo", "GPTNeoForTokenClassification"),
|
("gpt_neo", "GPTNeoForTokenClassification"),
|
||||||
("gpt_neox", "GPTNeoXForTokenClassification"),
|
("gpt_neox", "GPTNeoXForTokenClassification"),
|
||||||
|
("helium", "HeliumForTokenClassification"),
|
||||||
("ibert", "IBertForTokenClassification"),
|
("ibert", "IBertForTokenClassification"),
|
||||||
("layoutlm", "LayoutLMForTokenClassification"),
|
("layoutlm", "LayoutLMForTokenClassification"),
|
||||||
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
||||||
|
|||||||
@@ -226,6 +226,7 @@ else:
|
|||||||
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
|
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
|
||||||
("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
|
("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
|
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("hubert", ("Wav2Vec2CTCTokenizer", None)),
|
("hubert", ("Wav2Vec2CTCTokenizer", None)),
|
||||||
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
|
|||||||
27
src/transformers/models/helium/__init__.py
Normal file
27
src/transformers/models/helium/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_helium import *
|
||||||
|
from .modeling_helium import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||||
140
src/transformers/models/helium/configuration_helium.py
Normal file
140
src/transformers/models/helium/configuration_helium.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`HeliumModel`]. It is used to instantiate an Helium
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the Helium 2b model.
|
||||||
|
e.g. [kyutai/helium-2b](https://huggingface.co/kyutai/helium-2b)
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 48000):
|
||||||
|
Vocabulary size of the Helium model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`HeliumModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 2560):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 7040):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 20):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 20):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
The attention head dimension.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The legacy activation function. It is overwritten by the `hidden_activation`.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-08):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
rope_theta (`float`, *optional*, defaults to 100000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 3):
|
||||||
|
Padding token id.
|
||||||
|
eos_token_id (`int` | `list`, *optional*, defaults to 2):
|
||||||
|
End of stream token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
Beginning of stream token id.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||||
|
```python
|
||||||
|
>>> from transformers import HeliumModel, HeliumConfig
|
||||||
|
>>> # Initializing a Helium 2b style configuration
|
||||||
|
>>> configuration = HeliumConfig()
|
||||||
|
>>> # Initializing a model from the Helium 2b style configuration
|
||||||
|
>>> model = HeliumModel(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "helium"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=48000,
|
||||||
|
hidden_size=2560,
|
||||||
|
intermediate_size=7040,
|
||||||
|
num_hidden_layers=24,
|
||||||
|
num_attention_heads=20,
|
||||||
|
num_key_value_heads=20,
|
||||||
|
head_dim=128,
|
||||||
|
hidden_act="silu",
|
||||||
|
attention_dropout=0.0,
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-8,
|
||||||
|
use_cache=True,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=100000.0,
|
||||||
|
pad_token_id=3,
|
||||||
|
eos_token_id=2,
|
||||||
|
bos_token_id=1,
|
||||||
|
attention_bias=False,
|
||||||
|
mlp_bias=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.mlp_bias = mlp_bias
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["HeliumConfig"]
|
||||||
1065
src/transformers/models/helium/modeling_helium.py
Normal file
1065
src/transformers/models/helium/modeling_helium.py
Normal file
File diff suppressed because it is too large
Load Diff
171
src/transformers/models/helium/modular_helium.py
Normal file
171
src/transformers/models/helium/modular_helium.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
|
from ...utils import logging
|
||||||
|
from ..gemma.modeling_gemma import (
|
||||||
|
GemmaForCausalLM,
|
||||||
|
GemmaForSequenceClassification,
|
||||||
|
GemmaForTokenClassification,
|
||||||
|
)
|
||||||
|
from ..granite.modeling_granite import (
|
||||||
|
GraniteAttention,
|
||||||
|
)
|
||||||
|
from ..llama.modeling_llama import (
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaMLP,
|
||||||
|
LlamaModel,
|
||||||
|
LlamaPreTrainedModel,
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
)
|
||||||
|
from .configuration_helium import HeliumConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumRMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumMLP(LlamaMLP):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., 0::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (`torch.Tensor`): The query tensor.
|
||||||
|
k (`torch.Tensor`): The key tensor.
|
||||||
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||||
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||||
|
position_ids (`torch.Tensor`, *optional*):
|
||||||
|
Deprecated and unused.
|
||||||
|
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||||
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||||
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||||
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||||
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||||
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||||
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||||
|
Returns:
|
||||||
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||||
|
"""
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
|
||||||
|
# Interleave them instead of usual shape
|
||||||
|
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
|
||||||
|
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
|
||||||
|
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumAttention(GraniteAttention):
|
||||||
|
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__(config, layer_idx)
|
||||||
|
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||||
|
self.scaling = 1 / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumDecoderLayer(LlamaDecoderLayer):
|
||||||
|
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mlp = HeliumMLP(config)
|
||||||
|
self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumPreTrainedModel(LlamaPreTrainedModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumModel(HeliumPreTrainedModel, LlamaModel):
|
||||||
|
def __init__(self, config: HeliumConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.rotary_emb = HeliumRotaryEmbedding(config)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumForCausalLM(GemmaForCausalLM):
|
||||||
|
def __init__(self, config: HeliumConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = HeliumModel(config)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumForSequenceClassification(GemmaForSequenceClassification):
|
||||||
|
def __init__(self, config: HeliumConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = HeliumModel(config)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumForTokenClassification(GemmaForTokenClassification):
|
||||||
|
def __init__(self, config: HeliumConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = HeliumModel(config)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"HeliumPreTrainedModel",
|
||||||
|
"HeliumModel",
|
||||||
|
"HeliumForCausalLM",
|
||||||
|
"HeliumForSequenceClassification",
|
||||||
|
"HeliumForTokenClassification",
|
||||||
|
]
|
||||||
@@ -4981,6 +4981,41 @@ class GroupViTVisionModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumForCausalLM(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumForSequenceClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumForTokenClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class HieraBackbone(metaclass=DummyObject):
|
class HieraBackbone(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
0
tests/models/helium/__init__.py
Normal file
0
tests/models/helium/__init__.py
Normal file
110
tests/models/helium/test_modeling_helium.py
Normal file
110
tests/models/helium/test_modeling_helium.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch Helium model."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, HeliumConfig, is_torch_available
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_read_token,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ..gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
HeliumForCausalLM,
|
||||||
|
HeliumForSequenceClassification,
|
||||||
|
HeliumForTokenClassification,
|
||||||
|
HeliumModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumModelTester(GemmaModelTester):
|
||||||
|
if is_torch_available():
|
||||||
|
config_class = HeliumConfig
|
||||||
|
model_class = HeliumModel
|
||||||
|
for_causal_lm_class = HeliumForCausalLM
|
||||||
|
for_sequence_class = HeliumForSequenceClassification
|
||||||
|
for_token_class = HeliumForTokenClassification
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class HeliumModelTest(GemmaModelTest, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(HeliumModel, HeliumForCausalLM, HeliumForSequenceClassification, HeliumForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
all_generative_model_classes = (HeliumForCausalLM,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"feature-extraction": HeliumModel,
|
||||||
|
"text-classification": HeliumForSequenceClassification,
|
||||||
|
"token-classification": HeliumForTokenClassification,
|
||||||
|
"text-generation": HeliumForCausalLM,
|
||||||
|
"zero-shot": HeliumForSequenceClassification,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
test_headmasking = False
|
||||||
|
test_pruning = False
|
||||||
|
_is_stateful = True
|
||||||
|
model_split_percents = [0.5, 0.6]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = HeliumModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=HeliumConfig, hidden_size=37)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
# @require_torch_gpu
|
||||||
|
class HeliumIntegrationTest(unittest.TestCase):
|
||||||
|
input_text = ["Hello, today is a great day to"]
|
||||||
|
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
|
def test_model_2b(self):
|
||||||
|
model_id = "kyutai/helium-1-preview"
|
||||||
|
EXPECTED_TEXTS = [
|
||||||
|
"Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"
|
||||||
|
]
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, revision="refs/pr/1"
|
||||||
|
).to(torch_device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id, revision="refs/pr/1")
|
||||||
|
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
Reference in New Issue
Block a user