Add mistral common support (#38906)
* wip: correct docstrings * Add mistral-common support. * quality * wip: add requested methods * wip: fix tests * wip: add internally some methods not being supported in mistral-common * wip * wip: add opencv dependency and update test list * wip: add mistral-common to testing dependencies * wip: revert some test changes * wip: ci * wip: ci * clean * check * check * check * wip: add hf image format to apply_chat_template and return pixel_values * wip: make mistral-common non-installed safe * wip: clean zip * fix: from_pretrained * fix: path and base64 * fix: path and import root * wip: add docs * clean * clean * revert --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from transformers.testing_utils import HfDoctestModule, HfDocTestParser
|
||||
|
||||
NOT_DEVICE_TESTS = {
|
||||
"test_tokenization",
|
||||
"test_tokenization_mistral_common",
|
||||
"test_processor",
|
||||
"test_processing",
|
||||
"test_beam_constraints",
|
||||
|
||||
@@ -139,6 +139,10 @@ Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/bl
|
||||
|
||||
[[autodoc]] MistralConfig
|
||||
|
||||
## MistralCommonTokenizer
|
||||
|
||||
[[autodoc]] MistralCommonTokenizer
|
||||
|
||||
## MistralModel
|
||||
|
||||
[[autodoc]] MistralModel
|
||||
|
||||
@@ -227,6 +227,10 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati
|
||||
|
||||
[[autodoc]] Mistral3Config
|
||||
|
||||
## MistralCommonTokenizer
|
||||
|
||||
[[autodoc]] MistralCommonTokenizer
|
||||
|
||||
## Mistral3Model
|
||||
|
||||
[[autodoc]] Mistral3Model
|
||||
|
||||
@@ -197,6 +197,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
|
||||
[[autodoc]] MixtralConfig
|
||||
|
||||
## MistralCommonTokenizer
|
||||
|
||||
[[autodoc]] MistralCommonTokenizer
|
||||
|
||||
## MixtralModel
|
||||
|
||||
[[autodoc]] MixtralModel
|
||||
|
||||
@@ -86,6 +86,10 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
|
||||
|
||||
[[autodoc]] PixtralVisionConfig
|
||||
|
||||
## MistralCommonTokenizer
|
||||
|
||||
[[autodoc]] MistralCommonTokenizer
|
||||
|
||||
## PixtralVisionModel
|
||||
|
||||
[[autodoc]] PixtralVisionModel
|
||||
|
||||
4
setup.py
4
setup.py
@@ -204,6 +204,7 @@ _deps = [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-exporter-otlp",
|
||||
"opentelemetry-sdk",
|
||||
"mistral-common[opencv]>=1.6.3",
|
||||
]
|
||||
|
||||
|
||||
@@ -334,6 +335,7 @@ extras["video"] = deps_list("av")
|
||||
extras["num2words"] = deps_list("num2words")
|
||||
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
||||
extras["tiktoken"] = deps_list("tiktoken", "blobfile")
|
||||
extras["mistral-common"] = deps_list("mistral-common[opencv]")
|
||||
extras["testing"] = (
|
||||
deps_list(
|
||||
"pytest",
|
||||
@@ -363,6 +365,7 @@ extras["testing"] = (
|
||||
)
|
||||
+ extras["retrieval"]
|
||||
+ extras["modelcreation"]
|
||||
+ extras["mistral-common"]
|
||||
)
|
||||
|
||||
extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
|
||||
@@ -384,6 +387,7 @@ extras["all"] = (
|
||||
+ extras["accelerate"]
|
||||
+ extras["video"]
|
||||
+ extras["num2words"]
|
||||
+ extras["mistral-common"]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from .utils import (
|
||||
is_g2p_en_available,
|
||||
is_keras_nlp_available,
|
||||
is_librosa_available,
|
||||
is_mistral_common_available,
|
||||
is_pretty_midi_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
@@ -310,6 +311,18 @@ else:
|
||||
"convert_slow_tokenizer",
|
||||
]
|
||||
|
||||
try:
|
||||
if not (is_mistral_common_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_mistral_common_objects
|
||||
|
||||
_import_structure["utils.dummy_mistral_common_objects"] = [
|
||||
name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"]
|
||||
|
||||
# Vision-specific objects
|
||||
try:
|
||||
if not is_vision_available():
|
||||
|
||||
@@ -106,4 +106,5 @@ deps = {
|
||||
"opentelemetry-api": "opentelemetry-api",
|
||||
"opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp",
|
||||
"opentelemetry-sdk": "opentelemetry-sdk",
|
||||
"mistral-common[opencv]": "mistral-common[opencv]>=1.6.3",
|
||||
}
|
||||
|
||||
@@ -21,6 +21,8 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from transformers.utils.import_utils import is_mistral_common_available
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
|
||||
@@ -387,15 +389,19 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
(
|
||||
"mistral",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
"MistralCommonTokenizer"
|
||||
if is_mistral_common_available()
|
||||
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"mixtral",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
"MistralCommonTokenizer"
|
||||
if is_mistral_common_available()
|
||||
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
||||
),
|
||||
),
|
||||
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
@@ -490,7 +496,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("phobert", ("PhobertTokenizer", None)),
|
||||
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"pixtral",
|
||||
(
|
||||
None,
|
||||
"MistralCommonTokenizer"
|
||||
if is_mistral_common_available()
|
||||
else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
|
||||
),
|
||||
),
|
||||
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("prophetnet", ("ProphetNetTokenizer", None)),
|
||||
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
@@ -721,8 +735,10 @@ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
|
||||
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
|
||||
if class_name in tokenizers:
|
||||
module_name = model_type_to_module_name(module_name)
|
||||
|
||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||
if module_name in ["mistral", "mixtral"] and class_name == "MistralCommonTokenizer":
|
||||
module = importlib.import_module(".tokenization_mistral_common", "transformers")
|
||||
else:
|
||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
|
||||
@@ -108,6 +108,7 @@ from .utils import (
|
||||
is_librosa_available,
|
||||
is_liger_kernel_available,
|
||||
is_lomo_available,
|
||||
is_mistral_common_available,
|
||||
is_natten_available,
|
||||
is_nltk_available,
|
||||
is_onnx_available,
|
||||
@@ -1526,6 +1527,13 @@ def require_speech(test_case):
|
||||
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
|
||||
|
||||
|
||||
def require_mistral_common(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
|
||||
"""
|
||||
return unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")(test_case)
|
||||
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
||||
|
||||
1830
src/transformers/tokenization_mistral_common.py
Normal file
1830
src/transformers/tokenization_mistral_common.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -182,6 +182,7 @@ from .import_utils import (
|
||||
is_liger_kernel_available,
|
||||
is_lomo_available,
|
||||
is_matplotlib_available,
|
||||
is_mistral_common_available,
|
||||
is_mlx_available,
|
||||
is_natten_available,
|
||||
is_ninja_available,
|
||||
|
||||
9
src/transformers/utils/dummy_mistral_common_objects.py
Normal file
9
src/transformers/utils/dummy_mistral_common_objects.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class MistralCommonTokenizer(metaclass=DummyObject):
|
||||
_backends = ["mistral-common"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["mistral-common"])
|
||||
@@ -227,6 +227,7 @@ _spqr_available = _is_package_available("spqr_quant")
|
||||
_rich_available = _is_package_available("rich")
|
||||
_kernels_available = _is_package_available("kernels")
|
||||
_matplotlib_available = _is_package_available("matplotlib")
|
||||
_mistral_common_available = _is_package_available("mistral_common")
|
||||
|
||||
_torch_version = "N/A"
|
||||
_torch_available = False
|
||||
@@ -1575,6 +1576,10 @@ def is_matplotlib_available():
|
||||
return _matplotlib_available
|
||||
|
||||
|
||||
def is_mistral_common_available():
|
||||
return _mistral_common_available
|
||||
|
||||
|
||||
def check_torch_load_is_safe():
|
||||
if not is_torch_greater_or_equal("2.6"):
|
||||
raise ValueError(
|
||||
@@ -1979,6 +1984,11 @@ RICH_IMPORT_ERROR = """
|
||||
rich`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
MISTRAL_COMMON_IMPORT_ERROR = """
|
||||
{0} requires the mistral-common library but it was not found in your environment. You can install it with pip: `pip install mistral-common`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("av", (is_av_available, AV_IMPORT_ERROR)),
|
||||
@@ -2031,6 +2041,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
|
||||
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
|
||||
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
|
||||
("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
1655
tests/test_tokenization_mistral_common.py
Normal file
1655
tests/test_tokenization_mistral_common.py
Normal file
File diff suppressed because one or more lines are too long
@@ -1164,7 +1164,7 @@ def parse_commit_message(commit_message: str) -> dict[str, bool]:
|
||||
JOB_TO_TEST_FILE = {
|
||||
"tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
|
||||
"tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
|
||||
"tests_tokenization": r"tests/models/.*/test_tokenization.*",
|
||||
"tests_tokenization": r"tests/(?:models/.*/test_tokenization.*|test_tokenization_mistral_common\.py)",
|
||||
"tests_processors": r"tests/models/.*/test_(?!(?:modeling_|tokenization_)).*", # takes feature extractors, image processors, processors
|
||||
"examples_torch": r"examples/pytorch/.*test_.*",
|
||||
"tests_exotic_models": r"tests/models/.*(?=layoutlmv|nat|deta|udop|nougat).*",
|
||||
|
||||
Reference in New Issue
Block a user