Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
345b9b1a6a | ||
|
|
ea5ca81cf9 | ||
|
|
711bed1f43 | ||
|
|
56ee444d2a | ||
|
|
8d8fdb5202 | ||
|
|
e3934198a3 | ||
|
|
3001543b94 | ||
|
|
b94f5fdd7e | ||
|
|
d02d006cf3 | ||
|
|
b102ab26c7 |
4
setup.py
4
setup.py
@@ -158,7 +158,7 @@ _deps = [
|
||||
"ruff==0.1.5",
|
||||
"sacrebleu>=1.4.12,<2.0.0",
|
||||
"sacremoses",
|
||||
"safetensors>=0.3.1",
|
||||
"safetensors>=0.4.1",
|
||||
"sagemaker>=2.31.0",
|
||||
"scikit-learn",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
@@ -428,7 +428,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.37.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.37.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.37.0"
|
||||
__version__ = "4.37.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -772,7 +772,6 @@ _import_structure = {
|
||||
"SiglipConfig",
|
||||
"SiglipProcessor",
|
||||
"SiglipTextConfig",
|
||||
"SiglipTokenizer",
|
||||
"SiglipVisionConfig",
|
||||
],
|
||||
"models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
|
||||
@@ -1124,6 +1123,7 @@ else:
|
||||
_import_structure["models.reformer"].append("ReformerTokenizer")
|
||||
_import_structure["models.rembert"].append("RemBertTokenizer")
|
||||
_import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizer")
|
||||
_import_structure["models.siglip"].append("SiglipTokenizer")
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
|
||||
_import_structure["models.speecht5"].append("SpeechT5Tokenizer")
|
||||
_import_structure["models.t5"].append("T5Tokenizer")
|
||||
@@ -5494,7 +5494,6 @@ if TYPE_CHECKING:
|
||||
SiglipConfig,
|
||||
SiglipProcessor,
|
||||
SiglipTextConfig,
|
||||
SiglipTokenizer,
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
|
||||
@@ -5843,6 +5842,7 @@ if TYPE_CHECKING:
|
||||
from .models.reformer import ReformerTokenizer
|
||||
from .models.rembert import RemBertTokenizer
|
||||
from .models.seamless_m4t import SeamlessM4TTokenizer
|
||||
from .models.siglip import SiglipTokenizer
|
||||
from .models.speech_to_text import Speech2TextTokenizer
|
||||
from .models.speecht5 import SpeechT5Tokenizer
|
||||
from .models.t5 import T5Tokenizer
|
||||
|
||||
@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
|
||||
if compare_with_pt_model:
|
||||
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
state_dict = torch.load(
|
||||
pytorch_checkpoint_path,
|
||||
map_location="cpu",
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
pt_model = pt_model_class.from_pretrained(
|
||||
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
||||
|
||||
@@ -64,7 +64,7 @@ deps = {
|
||||
"ruff": "ruff==0.1.5",
|
||||
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
|
||||
"sacremoses": "sacremoses",
|
||||
"safetensors": "safetensors>=0.3.1",
|
||||
"safetensors": "safetensors>=0.4.1",
|
||||
"sagemaker": "sagemaker>=2.31.0",
|
||||
"scikit-learn": "scikit-learn",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
|
||||
@@ -289,6 +289,11 @@ BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
|
||||
BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
|
||||
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
|
||||
|
||||
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
|
||||
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
|
||||
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
|
||||
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
|
||||
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
|
||||
|
||||
# Typing shortcuts
|
||||
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
|
||||
|
||||
@@ -27,10 +27,13 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
|
||||
import transformers
|
||||
|
||||
from . import is_safetensors_available
|
||||
from . import is_safetensors_available, is_torch_available
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_safetensors_available():
|
||||
from safetensors import safe_open
|
||||
from safetensors.flax import load_file as safe_load_file
|
||||
@@ -48,17 +51,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
|
||||
):
|
||||
"""Load pytorch checkpoints in a flax model"""
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
|
||||
if not is_sharded:
|
||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||
@@ -66,12 +58,25 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
|
||||
if pt_path.endswith(".safetensors"):
|
||||
pt_state_dict = {}
|
||||
with safe_open(pt_path, framework="pt") as f:
|
||||
with safe_open(pt_path, framework="flax") as f:
|
||||
for k in f.keys():
|
||||
pt_state_dict[k] = f.get_tensor(k)
|
||||
else:
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||
|
||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||
else:
|
||||
@@ -149,21 +154,17 @@ def rename_key_and_reshape_tensor(
|
||||
|
||||
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
# convert pytorch tensor to numpy
|
||||
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
|
||||
bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
|
||||
|
||||
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
||||
pt_state_dict = {
|
||||
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
|
||||
}
|
||||
|
||||
if from_bin:
|
||||
for k, v in pt_state_dict.items():
|
||||
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
||||
if v.dtype == bfloat16:
|
||||
v = v.float()
|
||||
pt_state_dict[k] = v.numpy()
|
||||
|
||||
model_prefix = flax_model.base_model_prefix
|
||||
|
||||
@@ -191,7 +192,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
# Need to change some parameters name to match Flax names
|
||||
for pt_key, pt_tensor in pt_state_dict.items():
|
||||
pt_tuple_key = tuple(pt_key.split("."))
|
||||
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
|
||||
is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
|
||||
|
||||
# remove base model prefix if necessary
|
||||
has_base_model_prefix = pt_tuple_key[0] == model_prefix
|
||||
@@ -229,7 +230,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
flax_state_dict[("params",) + flax_key] = (
|
||||
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
|
||||
)
|
||||
|
||||
else:
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[flax_key] = (
|
||||
@@ -253,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||
flax_state_dict = {}
|
||||
for shard_file in shard_filenames:
|
||||
# load using msgpack utils
|
||||
pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
|
||||
model_prefix = flax_model.base_model_prefix
|
||||
|
||||
@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
|
||||
if pt_path.endswith(".safetensors"):
|
||||
state_dict = safe_load_file(pt_path)
|
||||
else:
|
||||
state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||
|
||||
pt_state_dict.update(state_dict)
|
||||
|
||||
|
||||
@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
loader = (
|
||||
safe_load_file
|
||||
if load_safe
|
||||
else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
|
||||
)
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = loader(os.path.join(folder, shard_file))
|
||||
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
and is_zipfile(checkpoint_file)
|
||||
):
|
||||
extra_args = {"mmap": True}
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
return torch.load(
|
||||
checkpoint_file,
|
||||
map_location=map_location,
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
**extra_args,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -379,7 +379,7 @@ else:
|
||||
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("siglip", ("SiglipTokenizer", None)),
|
||||
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
|
||||
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
@@ -29,9 +30,17 @@ _import_structure = {
|
||||
"SiglipVisionConfig",
|
||||
],
|
||||
"processing_siglip": ["SiglipProcessor"],
|
||||
"tokenization_siglip": ["SiglipTokenizer"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_sentencepiece_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
|
||||
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -63,7 +72,14 @@ if TYPE_CHECKING:
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
from .processing_siglip import SiglipProcessor
|
||||
from .tokenization_siglip import SiglipTokenizer
|
||||
|
||||
try:
|
||||
if not is_sentencepiece_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tokenization_siglip import SiglipTokenizer
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
|
||||
@@ -1334,10 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
state_dict = torch.load(
|
||||
weight_path,
|
||||
map_location="cpu",
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
|
||||
except EnvironmentError:
|
||||
|
||||
@@ -317,6 +317,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
||||
@@ -331,6 +332,13 @@ class ProcessorMixin(PushToHubMixin):
|
||||
f" directory containing a {PROCESSOR_NAME} file"
|
||||
)
|
||||
|
||||
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
|
||||
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
|
||||
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
|
||||
# However, for models added in the future, we won't get the expected error if this file is missing.
|
||||
if resolved_processor_file is None:
|
||||
return {}, kwargs
|
||||
|
||||
try:
|
||||
# Load processor dict
|
||||
with open(resolved_processor_file, "r", encoding="utf-8") as reader:
|
||||
@@ -456,17 +464,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
kwargs["token"] = token
|
||||
|
||||
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
|
||||
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
|
||||
# However, for models added in the future, we won't get the expected error if this file is missing.
|
||||
try:
|
||||
processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
except EnvironmentError as e:
|
||||
if "does not appear to have a file named processor_config.json." in str(e):
|
||||
processor_dict, kwargs = {}, kwargs
|
||||
else:
|
||||
raise
|
||||
processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return cls.from_args_and_dict(args, processor_dict, **kwargs)
|
||||
|
||||
|
||||
@@ -2088,6 +2088,7 @@ class Trainer:
|
||||
)
|
||||
|
||||
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
# If the model is on the GPU, it still works!
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
|
||||
@@ -2106,7 +2107,7 @@ class Trainer:
|
||||
state_dict = torch.load(
|
||||
weights_file,
|
||||
map_location="cpu",
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
|
||||
state_dict["_smp_is_partial"] = False
|
||||
@@ -2123,7 +2124,7 @@ class Trainer:
|
||||
state_dict = torch.load(
|
||||
weights_file,
|
||||
map_location="cpu",
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
|
||||
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
|
||||
@@ -2176,6 +2177,7 @@ class Trainer:
|
||||
or os.path.exists(best_safe_adapter_model_path)
|
||||
):
|
||||
has_been_loaded = True
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
|
||||
# If the 'user_content.pt' file exists, load with the new smp api.
|
||||
@@ -2195,7 +2197,7 @@ class Trainer:
|
||||
state_dict = torch.load(
|
||||
best_model_path,
|
||||
map_location="cpu",
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
|
||||
state_dict["_smp_is_partial"] = False
|
||||
@@ -2228,7 +2230,7 @@ class Trainer:
|
||||
state_dict = torch.load(
|
||||
best_model_path,
|
||||
map_location="cpu",
|
||||
weights_only=is_torch_greater_or_equal_than_1_13,
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
|
||||
# If the model is on the GPU, it still works!
|
||||
@@ -2415,9 +2417,11 @@ class Trainer:
|
||||
os.rename(staging_output_dir, output_dir)
|
||||
|
||||
# Ensure rename completed in cases where os.rename is not atomic
|
||||
fd = os.open(output_dir, os.O_RDONLY)
|
||||
os.fsync(fd)
|
||||
os.close(fd)
|
||||
# And can only happen on non-windows based systems
|
||||
if os.name != "nt":
|
||||
fd = os.open(output_dir, os.O_RDONLY)
|
||||
os.fsync(fd)
|
||||
os.close(fd)
|
||||
|
||||
# Maybe delete some older checkpoints.
|
||||
if self.args.should_save:
|
||||
@@ -2905,13 +2909,19 @@ class Trainer:
|
||||
is_main_process=self.args.should_save,
|
||||
state_dict=model.state_dict(),
|
||||
save_function=xm.save,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||
state_dict = model.state_dict()
|
||||
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
|
||||
model.save_pretrained(
|
||||
output_dir,
|
||||
is_main_process=self.args.should_save,
|
||||
save_function=xm.save,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
if self.tokenizer is not None and self.args.should_save:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
|
||||
@@ -184,6 +184,13 @@ class SeamlessM4TTokenizer(metaclass=DummyObject):
|
||||
requires_backends(self, ["sentencepiece"])
|
||||
|
||||
|
||||
class SiglipTokenizer(metaclass=DummyObject):
|
||||
_backends = ["sentencepiece"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["sentencepiece"])
|
||||
|
||||
|
||||
class Speech2TextTokenizer(metaclass=DummyObject):
|
||||
_backends = ["sentencepiece"]
|
||||
|
||||
|
||||
@@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_avai
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -42,6 +43,9 @@ if is_flax_available():
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
@@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||
@@ -265,7 +268,27 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
import torch
|
||||
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
model.save_pretrained(tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(tmp)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||
@@ -284,39 +307,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
|
||||
"""
|
||||
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
||||
saved in the "pt" format if torch isn't installed.
|
||||
"""
|
||||
if is_torch_available():
|
||||
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
||||
# PyTorch shouldn't be installed for this to work correctly.
|
||||
return
|
||||
|
||||
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
|
||||
"""
|
||||
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
||||
saved in the "pt" format if torch isn't installed.
|
||||
"""
|
||||
if is_torch_available():
|
||||
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
||||
# PyTorch shouldn't be installed for this to work correctly.
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||
|
||||
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
_ = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
||||
"""
|
||||
@@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_flax_from_torch(self):
|
||||
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
@@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
# This should discard the safetensors weights in favor of the msgpack sharded weights
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_from_pt_bf16(self):
|
||||
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
|
||||
# and without torch.
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_pt_bf16(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
flat_params_1 = flatten_dict(new_model.params)
|
||||
for value in flat_params_1.values():
|
||||
self.assertEqual(value.dtype, "bfloat16")
|
||||
|
||||
Reference in New Issue
Block a user