Add ModernBERT to Transformers (#35158)

* initial cut of modernbert for transformers

* small bug fixes

* fixes

* Update import

* Use compiled mlp->mlp_norm to match research implementation

* Propagate changes in modular to modeling

* Replace duplicate attn_out_dropout in favor of attention_dropout

cc @warner-benjamin let me know if the two should remain separate!

* Update BOS to CLS and EOS to SEP

Please confirm @warner-benjamin

* Set default classifier bias to False, matching research repo

* Update tie_word_embeddings description

* Fix _init_weights for ForMaskedLM

* Match base_model_prefix

* Add compiled_head to match research repo outputs

* Fix imports for ModernBertForMaskedLM

* Just use "gelu" default outright for classifier

* Fix config name typo: initalizer -> initializer

* Remove some unused parameters in docstring. Still lots to edit there!

* Compile the embeddings forward

Not having this resulted in very slight differences - so small it wasn't even noticed for the base model, only for the large model.

But the tiny difference for large propagated at the embedding layer through the rest of the model, leading to notable differences of ~0.0084 average per value, up to 0.2343 for the worst case.

* Add drafts for ForSequenceClassification/ForTokenClassification

* Add initial SDPA support (not exactly equivalent to FA2 yet!)

During testing, FA2 and SDPA still differ by about 0.0098 per value in the token embeddings. It still predicts the correct mask fills, but I'd like to get it fully 1-1 if possible.

* Only use attention dropout if training

* Add initial eager attention support (also not equivalent to FA2 yet!)

Frustratingly, I also can't get eager to be equivalent to FA2 (or sdpa), but it does get really close, i.e. avg ~0.010 difference per value.

Especially if I use fp32 for both FA2&eager, avg ~0.0029 difference per value

The fill-mask results are good with eager.

* Add initial tests, output_attentions, output_hidden_states, prune_heads

Tests are based on BERT, not all tests pass yet: 23 failed, 79 passed, 100 skipped

* Remove kwargs from ModernBertForMaskedLM

Disable sparse_prediction by default to match the normal HF, can be enabled via config

* Remove/adjust/skip improper tests; warn if padding but no attn mask

* Run formatting etc.

* Run python utils/custom_init_isort.py

* FlexAttention with unpadded sequences (matches FA2 within bf16 numerics)

* Reformat init_weights based on review

* self -> module in attention forwards

* Remove if config.tie_word_embeddings

* Reformat output projection on a different line

* Remove pruning

* Remove assert

* Call contiguous() to simplify paths

* Remove prune_qkv_linear_layer

* Format code

* Keep as kwargs, only use if needed

* Remove unused codepaths & related config options

* Remove 3d attn_mask test; fix token classification tuple output

* Reorder: attention_mask above position_ids, fixes gradient checkpointing

* Fix usage if no FA2 or torch v2.5+

* Make torch.compile/triton optional

Should we rename 'compile'? It's a bit vague

* Separate pooling options into separate functions (cls, mean) - cls as default

* Simplify _pad_modernbert_output, remove unused labels path

* Update tied weights to remove decoder.weight, simplify decoder loading

* Adaptively set config.compile based on hf_device_map/device/resize, etc.

* Update ModernBertConfig docstring

* Satisfy some consistency checks, add unfinished docs

* Only set compile to False if there's more than 1 device

* Add docstrings for public ModernBert classes

* Don't replace docstring returns - ends up being duplicate

* Fix mistake in toctree

* Reformat toctree

* Patched FlexAttention, SDPA, Eager with Local Attention

* Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial

both to match the original performance, and to get the highest inference speed without requiring users to manually pick FA2

* Patch test edge case with Idefics3 not working with 'attn_implementation="sdpa"'

* Repad all_hidden_states as well

* rename config.compile to reference_compile

* disable flex_attention since it crashes

* Update modernbert.md

* Using dtype min to mask in eager

* Fully remove flex attention for now

It's only compatible with the nightly torch 2.6, so we'll leave it be for now. It's also slower than eager/sdpa.

Also, update compile -> reference_compile in one more case

* Call contiguous to allow for .view()

* Copyright 2020 -> 2024

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update/simplify __init__ structure

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove "... if dropout_prob > 0 else identity"

As dropout with 0.0 should be efficient like identity

* re-use existing pad/unpad functions instead of creating new ones

* remove flexattention method

* Compute attention_mask and local_attention_mask once in modeling

* Simplify sequence classification prediction heads, only CLS now

Users can make custom heads if they feel like it

Also removes the unnecessary pool parameter

* Simplify module.training in eager attn

* Also export ModernBertPreTrainedModel

* Update the documentation with links to finetuning scripts

* Explain local_attention_mask parameter in docstring

* Simplify _autoset_attn_implementation, rely on super()

* Keep "in" to initialize Prediction head

Doublechecked with Benjamin that it's correct/what we used for pretraining

* add back mean pooling

* Use the pooling head in TokenClassification

* update copyright

* Reset config._attn_implementation_internal on failure

* Allow optional attention_mask in ForMaskedLM head

* fix failing run_slow tests

* Add links to the paper

* Remove unpad_no_grad, always pad/unpad without gradients

* local_attention_mask -> sliding_window_mask

* Revert "Use the pooling head in TokenClassification"

This reverts commit 99c38badd1dbce01d7aef41095fbf2f5cce87279.

There was no real motivation, no info on whether having this bigger head does anything useful.

* Simplify pooling, 2 options via if-else

---------

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Said Taghadouini <taghadouinisaid@gmail.com>
Co-authored-by: Benjamin Clavié <ben@clavie.eu>
Co-authored-by: Antoine Chaffin <ant54600@hotmail.fr>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Benjamin Warner
2024-12-19 08:03:35 -05:00
committed by GitHub
parent 56ff1e92fd
commit 667ed5635e
19 changed files with 3568 additions and 2 deletions

View File

@@ -498,6 +498,8 @@
title: mLUKE
- local: model_doc/mobilebert
title: MobileBERT
- local: model_doc/modernbert
title: ModernBert
- local: model_doc/mpnet
title: MPNet
- local: model_doc/mpt

View File

@@ -232,6 +232,7 @@ Flax), PyTorch, and/or TensorFlow.
| [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ |
| [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ |
| [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ |
| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ |
| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ |
| [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ |
| [MPT](model_doc/mpt) | ✅ | ❌ | ❌ |

View File

@@ -0,0 +1,91 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ModernBert
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/models?filter=modernbert">
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-modernbert-blueviolet">
</a>
<a href="https://arxiv.org/abs/2412.13663">
<img alt="Paper page" src="https://img.shields.io/badge/Paper%20page-2412.13663-green">
</a>
</div>
## Overview
The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli.
It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as:
- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens.
- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences.
- [GeGLU](https://arxiv.org/abs/2002.05202) layers replacing the original MLP layers, shown to improve performance.
- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers.
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing.
- A model design that follows [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs.
- Modern training data scales (2 trillion tokens) and mixtures (including code and math data).
The abstract from the paper is the following:
*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.*
The original code can be found [here](https://github.com/answerdotai/modernbert).
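The alternating attention scheme listed in the overview can be illustrated with a small pure-Python sketch. It is not the library implementation: the helper names `is_global_layer` and `sliding_window_mask` are hypothetical, and two details are assumptions rather than facts stated on this page, namely that layer 0 is global (with every third layer after it) and that a window of 128 means a token sees up to 64 neighbours on each side.

```python
from typing import List


def is_global_layer(layer_id: int, every_n: int = 3) -> bool:
    """Assumption: layer 0 uses global attention, then every `every_n`-th layer."""
    return layer_id % every_n == 0


def sliding_window_mask(seq_len: int, window: int = 128) -> List[List[bool]]:
    """Entry [i][j] is True when token i may attend to token j.

    Assumption: the window is centred, so |i - j| <= window // 2.
    """
    half = window // 2
    return [[abs(i - j) <= half for j in range(seq_len)] for i in range(seq_len)]


# With the default 22 layers, these layers would attend globally.
global_layers = [i for i in range(22) if is_global_layer(i)]
print(global_layers)  # [0, 3, 6, 9, 12, 15, 18, 21]

# A tiny window makes the band structure visible: token 0 only sees tokens 0..2.
mask = sliding_window_mask(seq_len=8, window=4)
print(mask[0])
```

Under these assumptions, 8 of the 22 default layers see the full sequence, while the other 14 only see a local band around each token.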
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ModernBert.
<PipelineTag pipeline="sentence-similarity"/>
- A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎
- A script on how to [finetune for information retrieval with PyLate](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_pylate.py). 🌎
<PipelineTag pipeline="fill-mask"/>
- [Masked language modeling task guide](../tasks/masked_language_modeling)
## ModernBertConfig
[[autodoc]] ModernBertConfig
<frameworkcontent>
<pt>
## ModernBertModel
[[autodoc]] ModernBertModel
- forward
## ModernBertForMaskedLM
[[autodoc]] ModernBertForMaskedLM
- forward
## ModernBertForSequenceClassification
[[autodoc]] ModernBertForSequenceClassification
- forward
## ModernBertForTokenClassification
[[autodoc]] ModernBertForTokenClassification
- forward
</pt>
</frameworkcontent>

View File

@@ -74,6 +74,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBertModel)
* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
@@ -265,6 +266,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBertModel)
* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)

View File

@@ -606,6 +606,7 @@ _import_structure = {
"models.mobilenet_v2": ["MobileNetV2Config"],
"models.mobilevit": ["MobileViTConfig"],
"models.mobilevitv2": ["MobileViTV2Config"],
"models.modernbert": ["ModernBertConfig"],
"models.moshi": [
"MoshiConfig",
"MoshiDepthConfig",
@@ -2869,6 +2870,15 @@ else:
"MobileViTV2PreTrainedModel",
]
)
_import_structure["models.modernbert"].extend(
[
"ModernBertForMaskedLM",
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"ModernBertModel",
"ModernBertPreTrainedModel",
]
)
_import_structure["models.moshi"].extend(
[
"MoshiForCausalLM",
@@ -5565,6 +5575,7 @@ if TYPE_CHECKING:
from .models.mobilevitv2 import (
MobileViTV2Config,
)
from .models.modernbert import ModernBertConfig
from .models.moshi import (
MoshiConfig,
MoshiDepthConfig,
@@ -7556,6 +7567,13 @@ if TYPE_CHECKING:
MobileViTV2Model,
MobileViTV2PreTrainedModel,
)
from .models.modernbert import (
ModernBertForMaskedLM,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertModel,
ModernBertPreTrainedModel,
)
from .models.moshi import (
MoshiForCausalLM,
MoshiForConditionalGeneration,

View File

@@ -47,6 +47,22 @@ def ForCausalLMLoss(
    return loss


def ForMaskedLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    labels = labels.view(-1)
    # Enable model parallelism
    labels = labels.to(logits.device)
    loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
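
As a rough illustration of what `ForMaskedLMLoss` computes, the following pure-Python sketch mirrors the flatten-then-cross-entropy steps without PyTorch. The function name `masked_lm_loss` is hypothetical. The normalisation rule (sum divided by `num_items_in_batch` when it is given, otherwise the mean over non-ignored positions) is an assumption about `fixed_cross_entropy`, whose body is not shown in this diff.

```python
import math
from typing import List, Optional

IGNORE_INDEX = -100


def masked_lm_loss(
    logits: List[List[float]],  # already flattened: one row of vocab scores per token
    labels: List[int],  # one label per token, IGNORE_INDEX for unmasked tokens
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = IGNORE_INDEX,
) -> float:
    total, counted = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # positions that were not masked contribute nothing
        # Numerically stable log-sum-exp for the softmax normaliser.
        row_max = max(row)
        log_z = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total += log_z - row[label]  # negative log-likelihood of the target
        counted += 1
    if num_items_in_batch is not None:
        # Assumed behaviour: summed loss normalised by the caller-provided count.
        return total / num_items_in_batch
    return total / counted if counted else float("nan")


# Two token positions, vocabulary of size 2; only the second position is scored.
example_logits = [[0.0, 0.0], [2.0, 0.0]]
example_labels = [IGNORE_INDEX, 0]
loss_mean = masked_lm_loss(example_logits, example_labels)
loss_scaled = masked_lm_loss(example_logits, example_labels, num_items_in_batch=2)
```

Passing `num_items_in_batch` matters mainly for gradient accumulation, where each micro-batch sees only part of the scored tokens.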
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
    num_labels = config.num_labels
    if config.problem_type is None:
@@ -101,6 +117,7 @@ def ForTokenClassification(logits, labels, config, **kwargs):
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,

View File

@@ -167,6 +167,7 @@ from . import (
mobilenet_v2,
mobilevit,
mobilevitv2,
modernbert,
moshi,
mpnet,
mpt,

View File

@@ -187,6 +187,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("mobilenet_v2", "MobileNetV2Config"),
("mobilevit", "MobileViTConfig"),
("mobilevitv2", "MobileViTV2Config"),
("modernbert", "ModernBertConfig"),
("moshi", "MoshiConfig"),
("mpnet", "MPNetConfig"),
("mpt", "MptConfig"),
@@ -510,6 +511,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mobilenet_v2", "MobileNetV2"),
("mobilevit", "MobileViT"),
("mobilevitv2", "MobileViTV2"),
("modernbert", "ModernBERT"),
("moshi", "Moshi"),
("mpnet", "MPNet"),
("mpt", "MPT"),

View File

@@ -176,6 +176,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mobilenet_v2", "MobileNetV2Model"),
("mobilevit", "MobileViTModel"),
("mobilevitv2", "MobileViTV2Model"),
("modernbert", "ModernBertModel"),
("moshi", "MoshiModel"),
("mpnet", "MPNetModel"),
("mpt", "MptModel"),
@@ -838,6 +839,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("modernbert", "ModernBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("mra", "MraForMaskedLM"),
("mvp", "MvpForConditionalGeneration"),
@@ -992,6 +994,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mistral", "MistralForSequenceClassification"),
("mixtral", "MixtralForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("modernbert", "ModernBertForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("mpt", "MptForSequenceClassification"),
("mra", "MraForSequenceClassification"),
@@ -1178,6 +1181,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mistral", "MistralForTokenClassification"),
("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("modernbert", "ModernBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"),

View File

@@ -313,6 +313,7 @@ else:
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),

View File

@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_modernbert import *
    from .modeling_modernbert import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,213 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_modernbert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
from ...configuration_utils import PretrainedConfig
class ModernBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate a ModernBert
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the ModernBERT-base.
e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50368):
Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`ModernBertModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 1152):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 22):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder. Will default to `"gelu"`
if not specified.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the normalization layers.
norm_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the normalization layers.
pad_token_id (`int`, *optional*, defaults to 50283):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 50282):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 50281):
Beginning of stream token id.
cls_token_id (`int`, *optional*, defaults to 50281):
Classification token id.
sep_token_id (`int`, *optional*, defaults to 50282):
Separation token id.
global_rope_theta (`float`, *optional*, defaults to 160000.0):
The base period of the global RoPE embeddings.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
global_attn_every_n_layers (`int`, *optional*, defaults to 3):
The number of layers between global attention layers.
local_attention (`int`, *optional*, defaults to 128):
The window size for local attention.
local_rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the local RoPE embeddings.
embedding_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the embeddings.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the MLP layers.
mlp_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the MLP layers.
decoder_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the decoder layers.
classifier_pooling (`str`, *optional*, defaults to `"cls"`):
The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
CLS token doesn't attend to all tokens on long sequences.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the classifier.
classifier_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the classifier.
classifier_activation (`str`, *optional*, defaults to `"gelu"`):
The activation function for the classifier.
deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
sparse_prediction (`bool`, *optional*, defaults to `False`):
Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
The index to ignore for the sparse prediction.
reference_compile (`bool`, *optional*):
Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
be faster in some scenarios.
Examples:
```python
>>> from transformers import ModernBertModel, ModernBertConfig
>>> # Initializing a ModernBert style configuration
>>> configuration = ModernBertConfig()
>>> # Initializing a model from the modernbert-base style configuration
>>> model = ModernBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "modernbert"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=50368,
hidden_size=768,
intermediate_size=1152,
num_hidden_layers=22,
num_attention_heads=12,
hidden_activation="gelu",
max_position_embeddings=8192,
initializer_range=0.02,
initializer_cutoff_factor=2.0,
norm_eps=1e-5,
norm_bias=False,
pad_token_id=50283,
eos_token_id=50282,
bos_token_id=50281,
cls_token_id=50281,
sep_token_id=50282,
global_rope_theta=160000.0,
attention_bias=False,
attention_dropout=0.0,
global_attn_every_n_layers=3,
local_attention=128,
local_rope_theta=10000.0,
embedding_dropout=0.0,
mlp_bias=False,
mlp_dropout=0.0,
decoder_bias=True,
classifier_pooling: Literal["cls", "mean"] = "cls",
classifier_dropout=0.0,
classifier_bias=False,
classifier_activation="gelu",
deterministic_flash_attn=False,
sparse_prediction=False,
sparse_pred_ignore_index=-100,
reference_compile=None,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
cls_token_id=cls_token_id,
sep_token_id=sep_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.initializer_cutoff_factor = initializer_cutoff_factor
self.norm_eps = norm_eps
self.norm_bias = norm_bias
self.global_rope_theta = global_rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.global_attn_every_n_layers = global_attn_every_n_layers
self.local_attention = local_attention
self.local_rope_theta = local_rope_theta
self.embedding_dropout = embedding_dropout
self.mlp_bias = mlp_bias
self.mlp_dropout = mlp_dropout
self.decoder_bias = decoder_bias
self.classifier_pooling = classifier_pooling
self.classifier_dropout = classifier_dropout
self.classifier_bias = classifier_bias
self.classifier_activation = classifier_activation
self.deterministic_flash_attn = deterministic_flash_attn
self.sparse_prediction = sparse_prediction
self.sparse_pred_ignore_index = sparse_pred_ignore_index
self.reference_compile = reference_compile
if self.classifier_pooling not in ["cls", "mean"]:
raise ValueError(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)
__all__ = ["ModernBertConfig"]
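
The `reference_compile` docstring above lists four conditions under which the model is compiled. The modeling code that applies them is not shown in this section, so what follows is only a minimal sketch of that decision under stated assumptions: the helper name `resolve_reference_compile` and all of its parameters are hypothetical, and it assumes that the four conditions apply when `reference_compile` is `None` and that an explicit `True`/`False` is honored as given.

```python
from typing import Optional


def resolve_reference_compile(
    reference_compile: Optional[bool],
    *,
    triton_available: bool,
    device_type: str,
    shared_between_devices: bool,
    resized_after_init: bool,
) -> bool:
    """Hypothetical helper mirroring the four conditions in the docstring."""
    # Assumption: an explicit True/False set by the user is taken as-is.
    if reference_compile is not None:
        return reference_compile
    # Assumption: with None, compile only when all four documented conditions hold.
    return (
        triton_available
        and device_type != "mps"
        and not shared_between_devices
        and not resized_after_init
    )


# All four conditions hold, so the sketch opts into compilation.
print(
    resolve_reference_compile(
        None,
        triton_available=True,
        device_type="cuda",
        shared_between_devices=False,
        resized_after_init=False,
    )
)
```

Passing `reference_compile=False` through the config, as the tests in this PR do for `test_retain_grad_hidden_states_attentions`, short-circuits the check entirely.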

File diff suppressed because it is too large Load Diff

View File

@@ -6418,6 +6418,41 @@ class MobileViTV2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ModernBertForMaskedLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ModernBertForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ModernBertForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ModernBertModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ModernBertPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MoshiForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -192,7 +192,7 @@ _hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
_tiktoken_available = _is_package_available("tiktoken")
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
_triton_available = _is_package_available("triton")
_torch_version = "N/A"
_torch_available = False
@@ -1243,6 +1243,10 @@ def is_liger_kernel_available():
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")
def is_triton_available():
return _triton_available
# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:

View File

@@ -0,0 +1,367 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import pytest
from transformers import ModernBertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
CaptureLogger,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
ModernBertForMaskedLM,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertModel,
logging,
)
class ModernBertModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
pad_token_id=0,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_activation="gelu",
mlp_dropout=0.0,
attention_dropout=0.0,
embedding_dropout=0.0,
classifier_dropout=0.0,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_activation = hidden_activation
self.mlp_dropout = mlp_dropout
self.attention_dropout = attention_dropout
self.embedding_dropout = embedding_dropout
self.classifier_dropout = classifier_dropout
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
"""
Returns a tiny configuration by default.
"""
config = ModernBertConfig(
vocab_size=self.vocab_size,
pad_token_id=self.pad_token_id,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_activation=self.hidden_activation,
mlp_dropout=self.mlp_dropout,
attention_dropout=self.attention_dropout,
embedding_dropout=self.embedding_dropout,
classifier_dropout=self.classifier_dropout,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
)
if test := os.environ.get("PYTEST_CURRENT_TEST", False):
test_name = test.split(":")[-1].split(" ")[0]
# If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
# that compilation doesn't work. Users can then set `reference_compile=False` when loading the model,
# much like here. We're testing whether it works once they've done that.
if test_name == "test_retain_grad_hidden_states_attentions":
config.reference_compile = False
# Some tests require attentions to be returned. In that case we set the attention implementation to eager,
# as the other implementations don't support returning attentions.
if test_name in (
"test_attention_outputs",
"test_hidden_states_output",
"test_retain_grad_hidden_states_attentions",
):
config._attn_implementation = "eager"
return config
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = ModernBertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_lm(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = ModernBertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = ModernBertForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = ModernBertForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torchscript = False
all_model_classes = (
(
ModernBertModel,
ModernBertForMaskedLM,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = ()
pipeline_model_mapping = (
{
"feature-extraction": ModernBertModel,
"fill-mask": ModernBertForMaskedLM,
"text-classification": ModernBertForSequenceClassification,
"token-classification": ModernBertForTokenClassification,
"zero-shot": ModernBertForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = False
test_head_masking = False
test_pruning = False
model_split_percents = [0.5, 0.8, 0.9]
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if inputs_dict.get("output_attentions", False):
inputs_dict["output_attentions"] = True
if return_labels:
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
inputs_dict["next_sentence_label"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
return inputs_dict
def setUp(self):
self.model_tester = ModernBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=ModernBertConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
# The classifier.weight parameters of ModernBertForSequenceClassification and ModernBertForTokenClassification
# are initialized without `initializer_range`, so `_config_zero_init` does not set them to ~0.
if param.requires_grad and not (
name == "classifier.weight"
and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification]
):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
def test_inputs_embeds(self):
pass
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_warning_if_padding_and_no_attention_mask(self):
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.model_tester.prepare_config_and_inputs()
# Set pad tokens in the input_ids
input_ids[0, 0] = config.pad_token_id
# Check for warnings if the attention_mask is missing.
logger = logging.get_logger("transformers.modeling_utils")
# clear cache so we can test the warning is emitted (from `warning_once`).
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
model = ModernBertModel(config=config)
model.to(torch_device)
model.eval()
model(input_ids, attention_mask=None)
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
@unittest.skip("ModernBert doesn't use separate classes for SDPA, but a function instead.")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@slow
def test_model_from_pretrained(self):
model_name = "google-bert/bert-base-uncased"
model = ModernBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="ModernBert flash attention does not support right padding")
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):
"""
These still need to be written, once public models are available.
"""

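The `get_config` helper in the test file above reads `PYTEST_CURRENT_TEST` and reduces it to a bare test name with `test.split(":")[-1].split(" ")[0]`. The sketch below walks through that parsing on a sample value. The function name `parse_current_test_name` and the sample path are illustrative assumptions, not part of the PR; the value format (`path::Class::test_name (stage)`) is the one pytest sets for this environment variable.

```python
def parse_current_test_name(value: str) -> str:
    """Reduce a PYTEST_CURRENT_TEST value to the bare test function name."""
    # "path.py::Class::test_name (call)" -> "test_name (call)"
    last_segment = value.split(":")[-1]
    # "test_name (call)" -> "test_name"
    return last_segment.split(" ")[0]


# Illustrative value; the path is an assumption made for this example.
sample = "tests/test_modeling_example.py::ModernBertModelTest::test_attention_outputs (call)"
print(parse_current_test_name(sample))  # test_attention_outputs
```

Note that this parsing relies on the test id containing no spaces or colons of its own, so a parametrized id such as `test_x[a b]` would be cut short. That appears to be acceptable here because the tests being matched are plain method names.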
View File

@@ -3457,6 +3457,8 @@ class ModelTesterMixin:
"Data2VecAudioForSequenceClassification",
"UniSpeechForSequenceClassification",
"PvtForImageClassification",
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"TimmWrapperForImageClassification",
]
special_param_names = [
@@ -4042,7 +4044,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
try:
model_sdpa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa"
)
except ValueError:
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(