[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -18,16 +18,14 @@ import unittest
|
||||
import numpy as np
|
||||
from huggingface_hub import HfFolder, snapshot_download
|
||||
|
||||
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
|
||||
from transformers import BertConfig, is_flax_available
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
CaptureLogger,
|
||||
TemporaryHubRepo,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
|
||||
|
||||
@@ -42,9 +40,6 @@ if is_flax_available():
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
@@ -205,23 +200,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_flax
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_flax(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir)
|
||||
|
||||
# Check we have a model.safetensors file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
"""
|
||||
@@ -248,58 +226,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
import torch
|
||||
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
model.save_pretrained(tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(tmp)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||
safetensors_model = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
||||
"""
|
||||
@@ -328,19 +254,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_flax_from_torch(self):
|
||||
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
self.assertTrue(check_models_equal(hub_model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -370,27 +283,3 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_pt_bf16(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
flat_params_1 = flatten_dict(new_model.params)
|
||||
for value in flat_params_1.values():
|
||||
self.assertEqual(value.dtype, "bfloat16")
|
||||
|
||||
@@ -24,7 +24,7 @@ from math import isnan
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from ..test_modeling_tf_common import ids_tensor
|
||||
|
||||
@@ -48,20 +48,6 @@ if is_tf_available():
|
||||
)
|
||||
from transformers.modeling_tf_utils import keras
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
|
||||
try:
|
||||
tf.config.set_logical_device_configuration(
|
||||
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
|
||||
)
|
||||
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||
print("Logical GPUs", logical_gpus)
|
||||
except RuntimeError as e:
|
||||
# Virtual devices must be set before GPUs have been initialized
|
||||
print(e)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFCoreModelTesterMixin:
|
||||
|
||||
@@ -33,7 +33,6 @@ from transformers.testing_utils import ( # noqa: F401
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TemporaryHubRepo,
|
||||
_tf_gpu_memory_limit,
|
||||
is_staging_test,
|
||||
require_safetensors,
|
||||
require_tf,
|
||||
@@ -68,20 +67,6 @@ if is_tf_available():
|
||||
|
||||
tf.config.experimental.enable_tensor_float_32_execution(False)
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
|
||||
try:
|
||||
tf.config.set_logical_device_configuration(
|
||||
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
|
||||
)
|
||||
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||
print("Logical GPUs", logical_gpus)
|
||||
except RuntimeError as e:
|
||||
# Virtual devices must be set before GPUs have been initialized
|
||||
print(e)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFModelUtilsTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user