Compare commits

...

10 Commits

Author SHA1 Message Date
Amy Roberts
345b9b1a6a Release: v4.37.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2024-01-28 16:19:29 +00:00
amyeroberts
ea5ca81cf9 [Siglip] protect from imports if sentencepiece not installed (#28737)
[Siglip] protect from imports if sentencepiece not installed
2024-01-28 16:19:29 +00:00
Yih-Dar
711bed1f43 Fix weights_only (#28725)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2024-01-26 18:00:24 +00:00
Lysandre Debut
56ee444d2a Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599)
* Initial commit

* Requirements & tests

* Tests

* Tests

* Rogue import

* Rogue torch import

* Cleanup

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* bfloat16 management

* Sanchit's comments

* Import shield

* apply suggestions from code review

* correct bf16

* rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
2024-01-26 18:00:20 +00:00
Yih-Dar
8d8fdb5202 Don't fail when LocalEntryNotFoundError during processor_config.json loading (#28709)
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2024-01-26 17:57:40 +00:00
jeffhataws
e3934198a3 Use save_safetensor to disable safe serialization for XLA (#28669)
* Use save_safetensor to disable safe serialization for XLA

https://github.com/huggingface/transformers/issues/28438

* Style fixup
2024-01-26 17:57:28 +00:00
Zach Mueller
3001543b94 Fix windows err with checkpoint race conditions (#28637)
Fix windows err
2024-01-26 17:57:16 +00:00
amyeroberts
b94f5fdd7e [SigLIP] Only import tokenizer if sentencepiece available (#28636)
Only import class if sp available
2024-01-26 17:57:02 +00:00
Amy Roberts
d02d006cf3 Release: v4.37.1
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2024-01-24 15:24:07 +00:00
amyeroberts
b102ab26c7 Add back in generation types (#28681) 2024-01-24 15:22:28 +00:00
15 changed files with 170 additions and 103 deletions

View File

@@ -158,7 +158,7 @@ _deps = [
"ruff==0.1.5",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.3.1",
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92",
@@ -428,7 +428,7 @@ install_requires = [
setup(
name="transformers",
version="4.37.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.37.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.37.0"
__version__ = "4.37.2"
from typing import TYPE_CHECKING
@@ -772,7 +772,6 @@ _import_structure = {
"SiglipConfig",
"SiglipProcessor",
"SiglipTextConfig",
"SiglipTokenizer",
"SiglipVisionConfig",
],
"models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
@@ -1124,6 +1123,7 @@ else:
_import_structure["models.reformer"].append("ReformerTokenizer")
_import_structure["models.rembert"].append("RemBertTokenizer")
_import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizer")
_import_structure["models.siglip"].append("SiglipTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
_import_structure["models.speecht5"].append("SpeechT5Tokenizer")
_import_structure["models.t5"].append("T5Tokenizer")
@@ -5494,7 +5494,6 @@ if TYPE_CHECKING:
SiglipConfig,
SiglipProcessor,
SiglipTextConfig,
SiglipTokenizer,
SiglipVisionConfig,
)
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
@@ -5843,6 +5842,7 @@ if TYPE_CHECKING:
from .models.reformer import ReformerTokenizer
from .models.rembert import RemBertTokenizer
from .models.seamless_m4t import SeamlessM4TTokenizer
from .models.siglip import SiglipTokenizer
from .models.speech_to_text import Speech2TextTokenizer
from .models.speecht5 import SpeechT5Tokenizer
from .models.t5 import T5Tokenizer

View File

@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
if compare_with_pt_model:
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load(
pytorch_checkpoint_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
)
pt_model = pt_model_class.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict

View File

@@ -64,7 +64,7 @@ deps = {
"ruff": "ruff==0.1.5",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.3.1",
"safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",

View File

@@ -289,6 +289,11 @@ BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
# Typing shortcuts
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]

View File

@@ -27,10 +27,13 @@ from flax.traverse_util import flatten_dict, unflatten_dict
import transformers
from . import is_safetensors_available
from . import is_safetensors_available, is_torch_available
from .utils import logging
if is_torch_available():
import torch
if is_safetensors_available():
from safetensors import safe_open
from safetensors.flax import load_file as safe_load_file
@@ -48,17 +51,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
):
"""Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
if not is_sharded:
pt_path = os.path.abspath(pytorch_checkpoint_path)
@@ -66,12 +58,25 @@ def load_pytorch_checkpoint_in_flax_state_dict(
if pt_path.endswith(".safetensors"):
pt_state_dict = {}
with safe_open(pt_path, framework="pt") as f:
with safe_open(pt_path, framework="flax") as f:
for k in f.keys():
pt_state_dict[k] = f.get_tensor(k)
else:
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
try:
import torch # noqa: F401
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
else:
@@ -149,21 +154,17 @@ def rename_key_and_reshape_tensor(
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
try:
import torch # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}
if from_bin:
for k, v in pt_state_dict.items():
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16:
v = v.float()
pt_state_dict[k] = v.numpy()
model_prefix = flax_model.base_model_prefix
@@ -191,7 +192,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
@@ -229,7 +230,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = (
@@ -253,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
model_prefix = flax_model.base_model_prefix

View File

@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
pt_state_dict.update(state_dict)

View File

@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
loader = (
safe_load_file
if load_safe
else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
)
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
return torch.load(
checkpoint_file,
map_location=map_location,
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
**extra_args,
)
except Exception as e:

View File

@@ -379,7 +379,7 @@ else:
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
),
),
("siglip", ("SiglipTokenizer", None)),
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),

View File

@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_torch_available,
is_vision_available,
)
@@ -29,9 +30,17 @@ _import_structure = {
"SiglipVisionConfig",
],
"processing_siglip": ["SiglipProcessor"],
"tokenization_siglip": ["SiglipTokenizer"],
}
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
@@ -63,7 +72,14 @@ if TYPE_CHECKING:
SiglipVisionConfig,
)
from .processing_siglip import SiglipProcessor
from .tokenization_siglip import SiglipTokenizer
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_siglip import SiglipTokenizer
try:
if not is_vision_available():

View File

@@ -1334,10 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
cache_dir=cache_dir,
)
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load(
weight_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
)
except EnvironmentError:

View File

@@ -317,6 +317,7 @@ class ProcessorMixin(PushToHubMixin):
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
@@ -331,6 +332,13 @@ class ProcessorMixin(PushToHubMixin):
f" directory containing a {PROCESSOR_NAME} file"
)
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
# However, for models added in the future, we won't get the expected error if this file is missing.
if resolved_processor_file is None:
return {}, kwargs
try:
# Load processor dict
with open(resolved_processor_file, "r", encoding="utf-8") as reader:
@@ -456,17 +464,7 @@ class ProcessorMixin(PushToHubMixin):
kwargs["token"] = token
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
# However, for models added in the future, we won't get the expected error if this file is missing.
try:
processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
except EnvironmentError as e:
if "does not appear to have a file named processor_config.json." in str(e):
processor_dict, kwargs = {}, kwargs
else:
raise
processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_args_and_dict(args, processor_dict, **kwargs)

View File

@@ -2088,6 +2088,7 @@ class Trainer:
)
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
# If the model is on the GPU, it still works!
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2106,7 +2107,7 @@ class Trainer:
state_dict = torch.load(
weights_file,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False
@@ -2123,7 +2124,7 @@ class Trainer:
state_dict = torch.load(
weights_file,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
)
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2176,6 +2177,7 @@ class Trainer:
or os.path.exists(best_safe_adapter_model_path)
):
has_been_loaded = True
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
@@ -2195,7 +2197,7 @@ class Trainer:
state_dict = torch.load(
best_model_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
)
state_dict["_smp_is_partial"] = False
@@ -2228,7 +2230,7 @@ class Trainer:
state_dict = torch.load(
best_model_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
**weights_only_kwarg,
)
# If the model is on the GPU, it still works!
@@ -2415,9 +2417,11 @@ class Trainer:
os.rename(staging_output_dir, output_dir)
# Ensure rename completed in cases where os.rename is not atomic
fd = os.open(output_dir, os.O_RDONLY)
os.fsync(fd)
os.close(fd)
# And can only happen on non-windows based systems
if os.name != "nt":
fd = os.open(output_dir, os.O_RDONLY)
os.fsync(fd)
os.close(fd)
# Maybe delete some older checkpoints.
if self.args.should_save:
@@ -2905,13 +2909,19 @@ class Trainer:
is_main_process=self.args.should_save,
state_dict=model.state_dict(),
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
model.save_pretrained(
output_dir,
is_main_process=self.args.should_save,
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)

View File

@@ -184,6 +184,13 @@ class SeamlessM4TTokenizer(metaclass=DummyObject):
requires_backends(self, ["sentencepiece"])
class SiglipTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
class Speech2TextTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]

View File

@@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_avai
from transformers.testing_utils import (
TOKEN,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
require_safetensors,
require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
if is_flax_available():
@@ -42,6 +43,9 @@ if is_flax_available():
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
@require_flax
@is_staging_test
@@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
@@ -265,7 +268,27 @@ class FlaxModelUtilsTest(unittest.TestCase):
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
import torch
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp:
model.save_pretrained(tmp)
flax_model = FlaxBertModel.from_pretrained(tmp)
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
@@ -284,39 +307,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained(location)
@require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
"""
@@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_flax_from_torch(self):
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
@@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase):
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the msgpack sharded weights
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
@require_safetensors
def test_safetensors_from_pt_bf16(self):
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
# and without torch.
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_from_pt_bf16(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=False)
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
flat_params_1 = flatten_dict(new_model.params)
for value in flat_params_1.values():
self.assertEqual(value.dtype, "bfloat16")