[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -21,13 +21,10 @@ from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers import is_flax_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
|
||||
from transformers.testing_utils import CaptureLogger, require_flax
|
||||
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -47,17 +44,10 @@ if is_flax_available():
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
@@ -184,216 +174,6 @@ class FlaxModelTesterMixin:
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||
Currently unused, but it could make debugging easier and faster.
|
||||
|
||||
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||
"""
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(fx_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
|
||||
)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in fx_keys])
|
||||
self.check_pt_flax_outputs(
|
||||
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(
|
||||
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
|
||||
)
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(fx_outputs),
|
||||
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
|
||||
)
|
||||
else:
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
|
||||
|
||||
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
|
||||
if isinstance(pt_output, DynamicCache):
|
||||
pt_output = pt_output.to_legacy_cache()
|
||||
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
elif isinstance(fx_outputs, jnp.ndarray):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
|
||||
)
|
||||
|
||||
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||
fx_outputs = np.array(fx_outputs)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
self.assertEqual(
|
||||
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(fx_outputs):
|
||||
fx_outputs = np.array([fx_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
fx_nans = np.isnan(fx_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[fx_nans] = 0
|
||||
fx_outputs[fx_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
fx_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||
self.assertLessEqual(
|
||||
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
|
||||
f" {type(fx_outputs)} instead."
|
||||
)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
# It might be better to put this inside the for loop below (because we modify the config there).
|
||||
# But logically, it is fine.
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(
|
||||
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
|
||||
)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -474,92 +254,6 @@ class FlaxModelTesterMixin:
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = get_params(model.params)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -1119,14 +813,6 @@ class FlaxModelTesterMixin:
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_sharded_pt(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
|
||||
for key, ref_val in flatten_dict(ref_model.params).items():
|
||||
val = flatten_dict(model.params)[key]
|
||||
assert np.allclose(np.array(val), np.array(ref_val))
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user