Add mistral common support (#38906)

* wip: correct docstrings

* Add mistral-common support.

* quality

* wip: add requested methods

* wip: fix tests

* wip: add internally some methods not being supported in mistral-common

* wip

* wip: add opencv dependency and update test list

* wip: add mistral-common to testing dependencies

* wip: revert some test changes

* wip: ci

* wip: ci

* clean

* check

* check

* check

* wip: add hf image format to apply_chat_template and return pixel_values

* wip: make mistral-common non-installed safe

* wip: clean zip

* fix: from_pretrained

* fix: path and base64

* fix: path and import root

* wip: add docs

* clean

* clean

* revert

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Julien Denize
2025-07-11 18:26:58 +02:00
committed by GitHub
parent 665418dacc
commit 70e57e4710
16 changed files with 3573 additions and 8 deletions

View File

@@ -28,6 +28,7 @@ from transformers.testing_utils import HfDoctestModule, HfDocTestParser
NOT_DEVICE_TESTS = {
"test_tokenization",
"test_tokenization_mistral_common",
"test_processor",
"test_processing",
"test_beam_constraints",

View File

@@ -139,6 +139,10 @@ Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/bl
[[autodoc]] MistralConfig
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## MistralModel
[[autodoc]] MistralModel

View File

@@ -227,6 +227,10 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati
[[autodoc]] Mistral3Config
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## Mistral3Model
[[autodoc]] Mistral3Model

View File

@@ -197,6 +197,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] MixtralConfig
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## MixtralModel
[[autodoc]] MixtralModel

View File

@@ -86,6 +86,10 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
[[autodoc]] PixtralVisionConfig
## MistralCommonTokenizer
[[autodoc]] MistralCommonTokenizer
## PixtralVisionModel
[[autodoc]] PixtralVisionModel

View File

@@ -204,6 +204,7 @@ _deps = [
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-sdk",
"mistral-common[opencv]>=1.6.3",
]
@@ -334,6 +335,7 @@ extras["video"] = deps_list("av")
extras["num2words"] = deps_list("num2words")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["tiktoken"] = deps_list("tiktoken", "blobfile")
extras["mistral-common"] = deps_list("mistral-common[opencv]")
extras["testing"] = (
deps_list(
"pytest",
@@ -363,6 +365,7 @@ extras["testing"] = (
)
+ extras["retrieval"]
+ extras["modelcreation"]
+ extras["mistral-common"]
)
extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
@@ -384,6 +387,7 @@ extras["all"] = (
+ extras["accelerate"]
+ extras["video"]
+ extras["num2words"]
+ extras["mistral-common"]
)

View File

@@ -34,6 +34,7 @@ from .utils import (
is_g2p_en_available,
is_keras_nlp_available,
is_librosa_available,
is_mistral_common_available,
is_pretty_midi_available,
is_scipy_available,
is_sentencepiece_available,
@@ -310,6 +311,18 @@ else:
"convert_slow_tokenizer",
]
try:
if not (is_mistral_common_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_mistral_common_objects
_import_structure["utils.dummy_mistral_common_objects"] = [
name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
]
else:
_import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"]
# Vision-specific objects
try:
if not is_vision_available():

View File

@@ -106,4 +106,5 @@ deps = {
"opentelemetry-api": "opentelemetry-api",
"opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp",
"opentelemetry-sdk": "opentelemetry-sdk",
"mistral-common[opencv]": "mistral-common[opencv]>=1.6.3",
}

View File

@@ -21,6 +21,8 @@ import warnings
from collections import OrderedDict
from typing import Any, Optional, Union
from transformers.utils.import_utils import is_mistral_common_available
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -387,15 +389,19 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
(
"mistral",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
),
),
(
"mixtral",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
),
),
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
@@ -490,7 +496,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
(
"pixtral",
(
None,
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
),
),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
@@ -721,8 +735,10 @@ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers:
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
if module_name in ["mistral", "mixtral"] and class_name == "MistralCommonTokenizer":
module = importlib.import_module(".tokenization_mistral_common", "transformers")
else:
module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name)
except AttributeError:

View File

@@ -108,6 +108,7 @@ from .utils import (
is_librosa_available,
is_liger_kernel_available,
is_lomo_available,
is_mistral_common_available,
is_natten_available,
is_nltk_available,
is_onnx_available,
@@ -1526,6 +1527,13 @@ def require_speech(test_case):
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
def require_mistral_common(test_case):
"""
Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
"""
return unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")(test_case)
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch, tf or jax is used)

File diff suppressed because it is too large Load Diff

View File

@@ -182,6 +182,7 @@ from .import_utils import (
is_liger_kernel_available,
is_lomo_available,
is_matplotlib_available,
is_mistral_common_available,
is_mlx_available,
is_natten_available,
is_ninja_available,

View File

@@ -0,0 +1,9 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class MistralCommonTokenizer(metaclass=DummyObject):
_backends = ["mistral-common"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["mistral-common"])

View File

@@ -227,6 +227,7 @@ _spqr_available = _is_package_available("spqr_quant")
_rich_available = _is_package_available("rich")
_kernels_available = _is_package_available("kernels")
_matplotlib_available = _is_package_available("matplotlib")
_mistral_common_available = _is_package_available("mistral_common")
_torch_version = "N/A"
_torch_available = False
@@ -1575,6 +1576,10 @@ def is_matplotlib_available():
return _matplotlib_available
def is_mistral_common_available():
return _mistral_common_available
def check_torch_load_is_safe():
if not is_torch_greater_or_equal("2.6"):
raise ValueError(
@@ -1979,6 +1984,11 @@ RICH_IMPORT_ERROR = """
rich`. Please note that you may need to restart your runtime after installation.
"""
MISTRAL_COMMON_IMPORT_ERROR = """
{0} requires the mistral-common library but it was not found in your environment. You can install it with pip: `pip install mistral-common`. Please note that you may need to restart your runtime after installation.
"""
BACKENDS_MAPPING = OrderedDict(
[
("av", (is_av_available, AV_IMPORT_ERROR)),
@@ -2031,6 +2041,7 @@ BACKENDS_MAPPING = OrderedDict(
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
]
)

File diff suppressed because one or more lines are too long

View File

@@ -1164,7 +1164,7 @@ def parse_commit_message(commit_message: str) -> dict[str, bool]:
JOB_TO_TEST_FILE = {
"tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
"tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
"tests_tokenization": r"tests/models/.*/test_tokenization.*",
"tests_tokenization": r"tests/(?:models/.*/test_tokenization.*|test_tokenization_mistral_common\.py)",
"tests_processors": r"tests/models/.*/test_(?!(?:modeling_|tokenization_)).*", # takes feature extractors, image processors, processors
"examples_torch": r"examples/pytorch/.*test_.*",
"tests_exotic_models": r"tests/models/.*(?=layoutlmv|nat|deta|udop|nougat).*",