[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -32,7 +32,6 @@ from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
@@ -75,7 +74,6 @@ from transformers.models.auto.modeling_auto import (
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
is_pt_flax_cross_test,
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
@@ -100,14 +98,12 @@ from transformers.utils import (
|
||||
GENERATION_CONFIG_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_fx_available,
|
||||
is_torch_sdpa_available,
|
||||
)
|
||||
from transformers.utils.generic import ContextManagers, ModelOutput
|
||||
from transformers.utils.generic import ContextManagers
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -126,19 +122,6 @@ if is_torch_available():
|
||||
from transformers.modeling_utils import load_state_dict, no_init_weights
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
pass
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from tests.utils.test_modeling_flax_utils import check_models_equal
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
|
||||
|
||||
@@ -2552,249 +2535,6 @@ class ModelTesterMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||
Currently unused, but it could make debugging easier and faster.
|
||||
|
||||
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||
"""
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(fx_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
|
||||
)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in fx_keys])
|
||||
self.check_pt_flax_outputs(
|
||||
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(
|
||||
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
|
||||
)
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(fx_outputs),
|
||||
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
|
||||
)
|
||||
else:
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
|
||||
|
||||
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
|
||||
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
elif isinstance(fx_outputs, jnp.ndarray):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
|
||||
)
|
||||
|
||||
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||
fx_outputs = np.array(fx_outputs)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
self.assertEqual(
|
||||
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(fx_outputs):
|
||||
fx_outputs = np.array([fx_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
fx_nans = np.isnan(fx_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[fx_nans] = 0
|
||||
fx_outputs[fx_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
fx_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||
self.assertLessEqual(
|
||||
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
|
||||
f" {type(fx_outputs)} instead."
|
||||
)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(
|
||||
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
|
||||
)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -4413,29 +4153,6 @@ class ModelTesterMixin:
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_flax_from_pt_safetensors(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
flax_model_class_name = "Flax" + model_class.__name__ # Add the "Flax at the beginning
|
||||
if not hasattr(transformers, flax_model_class_name):
|
||||
self.skipTest(reason="transformers does not have this model in Flax version yet")
|
||||
|
||||
flax_model_class = getattr(transformers, flax_model_class_name)
|
||||
|
||||
pt_model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname, safe_serialization=True)
|
||||
flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
pt_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
# Check models are equal
|
||||
self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user