[tests] remove flax-pt equivalence and cross tests (#36283)

This commit is contained in:
Joao Gante
2025-02-19 15:13:27 +00:00
committed by GitHub
parent fa8cdccd91
commit 99adc74462
39 changed files with 33 additions and 3103 deletions

View File

@@ -18,16 +18,14 @@ import unittest
import numpy as np
from huggingface_hub import HfFolder, snapshot_download
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import (
TOKEN,
CaptureLogger,
TemporaryHubRepo,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
require_safetensors,
require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
@@ -42,9 +40,6 @@ if is_flax_available():
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
@require_flax
@is_staging_test
@@ -205,23 +200,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(model, new_model))
@require_flax
@require_torch
@is_pt_flax_cross_test
def test_safetensors_save_and_load_pt_to_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
pt_model.save_pretrained(tmp_dir)
# Check we have a model.safetensors file
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
new_model = FlaxBertModel.from_pretrained(tmp_dir)
# Check models are equal
self.assertTrue(check_models_equal(model, new_model))
@require_safetensors
def test_safetensors_load_from_hub(self):
"""
@@ -248,58 +226,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
import torch
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp:
model.save_pretrained(tmp)
flax_model = FlaxBertModel.from_pretrained(tmp)
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
"""
@@ -328,19 +254,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(model, new_model))
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_flax_from_torch(self):
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = FlaxBertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(hub_model, new_model))
@require_safetensors
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -370,27 +283,3 @@ class FlaxModelUtilsTest(unittest.TestCase):
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_from_pt_bf16(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=False)
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
flat_params_1 = flatten_dict(new_model.params)
for value in flat_params_1.values():
self.assertEqual(value.dtype, "bfloat16")