Add mistral common support (#38906)

* wip: correct docstrings

* Add mistral-common support.

* quality

* wip: add requested methods

* wip: fix tests

* wip: add internally some methods not being supported in mistral-common

* wip

* wip: add opencv dependency and update test list

* wip: add mistral-common to testing dependencies

* wip: revert some test changes

* wip: ci

* wip: ci

* clean

* check

* check

* check

* wip: add hf image format to apply_chat_template and return pixel_values

* wip: make mistral-common non-installed safe

* wip: clean zip

* fix: from_pretrained

* fix: path and base64

* fix: path and import root

* wip: add docs

* clean

* clean

* revert

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Julien Denize
2025-07-11 18:26:58 +02:00
committed by GitHub
parent 665418dacc
commit 70e57e4710
16 changed files with 3573 additions and 8 deletions

View File

@@ -28,6 +28,7 @@ from transformers.testing_utils import HfDoctestModule, HfDocTestParser
NOT_DEVICE_TESTS = { NOT_DEVICE_TESTS = {
"test_tokenization", "test_tokenization",
"test_tokenization_mistral_common",
"test_processor", "test_processor",
"test_processing", "test_processing",
"test_beam_constraints", "test_beam_constraints",

View File

@@ -139,6 +139,10 @@ Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/bl
[[autodoc]] MistralConfig [[autodoc]] MistralConfig
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## MistralModel ## MistralModel
[[autodoc]] MistralModel [[autodoc]] MistralModel

View File

@@ -227,6 +227,10 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati
[[autodoc]] Mistral3Config [[autodoc]] Mistral3Config
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## Mistral3Model ## Mistral3Model
[[autodoc]] Mistral3Model [[autodoc]] Mistral3Model

View File

@@ -197,6 +197,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] MixtralConfig [[autodoc]] MixtralConfig
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## MixtralModel ## MixtralModel
[[autodoc]] MixtralModel [[autodoc]] MixtralModel

View File

@@ -86,6 +86,10 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
[[autodoc]] PixtralVisionConfig [[autodoc]] PixtralVisionConfig
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## PixtralVisionModel ## PixtralVisionModel
[[autodoc]] PixtralVisionModel [[autodoc]] PixtralVisionModel

View File

@@ -204,6 +204,7 @@ _deps = [
"opentelemetry-api", "opentelemetry-api",
"opentelemetry-exporter-otlp", "opentelemetry-exporter-otlp",
"opentelemetry-sdk", "opentelemetry-sdk",
"mistral-common[opencv]>=1.6.3",
] ]
@@ -334,6 +335,7 @@ extras["video"] = deps_list("av")
extras["num2words"] = deps_list("num2words") extras["num2words"] = deps_list("num2words")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["tiktoken"] = deps_list("tiktoken", "blobfile") extras["tiktoken"] = deps_list("tiktoken", "blobfile")
extras["mistral-common"] = deps_list("mistral-common[opencv]")
extras["testing"] = ( extras["testing"] = (
deps_list( deps_list(
"pytest", "pytest",
@@ -363,6 +365,7 @@ extras["testing"] = (
) )
+ extras["retrieval"] + extras["retrieval"]
+ extras["modelcreation"] + extras["modelcreation"]
+ extras["mistral-common"]
) )
extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"] extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
@@ -384,6 +387,7 @@ extras["all"] = (
+ extras["accelerate"] + extras["accelerate"]
+ extras["video"] + extras["video"]
+ extras["num2words"] + extras["num2words"]
+ extras["mistral-common"]
) )

View File

@@ -34,6 +34,7 @@ from .utils import (
is_g2p_en_available, is_g2p_en_available,
is_keras_nlp_available, is_keras_nlp_available,
is_librosa_available, is_librosa_available,
is_mistral_common_available,
is_pretty_midi_available, is_pretty_midi_available,
is_scipy_available, is_scipy_available,
is_sentencepiece_available, is_sentencepiece_available,
@@ -310,6 +311,18 @@ else:
"convert_slow_tokenizer", "convert_slow_tokenizer",
] ]
# Register the mistral-common tokenizer only when its optional backend is installed.
try:
    if not is_mistral_common_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Backend missing: expose every public name of the dummy module instead, so the
    # names still resolve at import time.
    from .utils import dummy_mistral_common_objects

    _import_structure["utils.dummy_mistral_common_objects"] = [
        name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
    ]
else:
    _import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"]
# Vision-specific objects # Vision-specific objects
try: try:
if not is_vision_available(): if not is_vision_available():

View File

@@ -106,4 +106,5 @@ deps = {
"opentelemetry-api": "opentelemetry-api", "opentelemetry-api": "opentelemetry-api",
"opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp", "opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp",
"opentelemetry-sdk": "opentelemetry-sdk", "opentelemetry-sdk": "opentelemetry-sdk",
"mistral-common[opencv]": "mistral-common[opencv]>=1.6.3",
} }

View File

@@ -21,6 +21,8 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Optional, Union from typing import Any, Optional, Union
from transformers.utils.import_utils import is_mistral_common_available
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -387,15 +389,19 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
( (
"mistral", "mistral",
( (
"LlamaTokenizer" if is_sentencepiece_available() else None, "MistralCommonTokenizer"
"LlamaTokenizerFast" if is_tokenizers_available() else None, if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
), ),
), ),
( (
"mixtral", "mixtral",
( (
"LlamaTokenizer" if is_sentencepiece_available() else None, "MistralCommonTokenizer"
"LlamaTokenizerFast" if is_tokenizers_available() else None, if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
), ),
), ),
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
@@ -490,7 +496,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)), ("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), (
"pixtral",
(
None,
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
),
),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)), ("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
@@ -721,7 +735,9 @@ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers: if class_name in tokenizers:
module_name = model_type_to_module_name(module_name) module_name = model_type_to_module_name(module_name)
if module_name in ["mistral", "mixtral"] and class_name == "MistralCommonTokenizer":
module = importlib.import_module(".tokenization_mistral_common", "transformers")
else:
module = importlib.import_module(f".{module_name}", "transformers.models") module = importlib.import_module(f".{module_name}", "transformers.models")
try: try:
return getattr(module, class_name) return getattr(module, class_name)

View File

@@ -108,6 +108,7 @@ from .utils import (
is_librosa_available, is_librosa_available,
is_liger_kernel_available, is_liger_kernel_available,
is_lomo_available, is_lomo_available,
is_mistral_common_available,
is_natten_available, is_natten_available,
is_nltk_available, is_nltk_available,
is_onnx_available, is_onnx_available,
@@ -1526,6 +1527,13 @@ def require_speech(test_case):
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case) return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
def require_mistral_common(test_case):
    """
    Decorator for tests that depend on mistral-common.

    The decorated test is skipped with the reason "test requires mistral-common" whenever
    `is_mistral_common_available()` is falsy.
    """
    skip_without_backend = unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")
    return skip_without_backend(test_case)
def get_gpu_count(): def get_gpu_count():
""" """
Return the number of available gpus (regardless of whether torch, tf or jax is used) Return the number of available gpus (regardless of whether torch, tf or jax is used)

File diff suppressed because it is too large Load Diff

View File

@@ -182,6 +182,7 @@ from .import_utils import (
is_liger_kernel_available, is_liger_kernel_available,
is_lomo_available, is_lomo_available,
is_matplotlib_available, is_matplotlib_available,
is_mistral_common_available,
is_mlx_available, is_mlx_available,
is_natten_available, is_natten_available,
is_ninja_available, is_ninja_available,

View File

@@ -0,0 +1,9 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class MistralCommonTokenizer(metaclass=DummyObject):
_backends = ["mistral-common"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["mistral-common"])

View File

@@ -227,6 +227,7 @@ _spqr_available = _is_package_available("spqr_quant")
_rich_available = _is_package_available("rich") _rich_available = _is_package_available("rich")
_kernels_available = _is_package_available("kernels") _kernels_available = _is_package_available("kernels")
_matplotlib_available = _is_package_available("matplotlib") _matplotlib_available = _is_package_available("matplotlib")
_mistral_common_available = _is_package_available("mistral_common")
_torch_version = "N/A" _torch_version = "N/A"
_torch_available = False _torch_available = False
@@ -1575,6 +1576,10 @@ def is_matplotlib_available():
return _matplotlib_available return _matplotlib_available
def is_mistral_common_available():
    """Return the module-level flag recording whether the `mistral_common` package was found."""
    return _mistral_common_available
def check_torch_load_is_safe(): def check_torch_load_is_safe():
if not is_torch_greater_or_equal("2.6"): if not is_torch_greater_or_equal("2.6"):
raise ValueError( raise ValueError(
@@ -1979,6 +1984,11 @@ RICH_IMPORT_ERROR = """
rich`. Please note that you may need to restart your runtime after installation. rich`. Please note that you may need to restart your runtime after installation.
""" """
MISTRAL_COMMON_IMPORT_ERROR = """
{0} requires the mistral-common library but it was not found in your environment. You can install it with pip: `pip install mistral-common`. Please note that you may need to restart your runtime after installation.
"""
BACKENDS_MAPPING = OrderedDict( BACKENDS_MAPPING = OrderedDict(
[ [
("av", (is_av_available, AV_IMPORT_ERROR)), ("av", (is_av_available, AV_IMPORT_ERROR)),
@@ -2031,6 +2041,7 @@ BACKENDS_MAPPING = OrderedDict(
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)), ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)), ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)), ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
] ]
) )

File diff suppressed because one or more lines are too long

View File

@@ -1164,7 +1164,7 @@ def parse_commit_message(commit_message: str) -> dict[str, bool]:
JOB_TO_TEST_FILE = { JOB_TO_TEST_FILE = {
"tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*", "tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
"tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*", "tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
"tests_tokenization": r"tests/models/.*/test_tokenization.*", "tests_tokenization": r"tests/(?:models/.*/test_tokenization.*|test_tokenization_mistral_common\.py)",
"tests_processors": r"tests/models/.*/test_(?!(?:modeling_|tokenization_)).*", # takes feature extractors, image processors, processors "tests_processors": r"tests/models/.*/test_(?!(?:modeling_|tokenization_)).*", # takes feature extractors, image processors, processors
"examples_torch": r"examples/pytorch/.*test_.*", "examples_torch": r"examples/pytorch/.*test_.*",
"tests_exotic_models": r"tests/models/.*(?=layoutlmv|nat|deta|udop|nougat).*", "tests_exotic_models": r"tests/models/.*(?=layoutlmv|nat|deta|udop|nougat).*",