Add TF<>PT and Flax<>PT everywhere (#14047)
* up * up * up * up * up * up * up * add clip * fix clip PyTorch * fix clip PyTorch * up * up * up * up * up * up * up
This commit is contained in:
committed by
GitHub
parent
8560b55b5e
commit
0c3174c758
@@ -25,10 +25,13 @@ import unittest
|
||||
import warnings
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import HfApi, Repository
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
@@ -36,6 +39,8 @@ from transformers.testing_utils import (
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
is_pt_flax_cross_test,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_torch,
|
||||
require_torch_multi_gpu,
|
||||
@@ -45,7 +50,6 @@ from transformers.testing_utils import (
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -70,6 +74,13 @@ if is_torch_available():
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
@@ -1417,6 +1428,241 @@ class ModelTesterMixin:
|
||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||
)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
    """Check that each PyTorch model class and its ``TF``-prefixed counterpart agree.

    For every class in ``self.all_model_classes`` the test:

    1. builds a PyTorch model and a TensorFlow model from the same config,
    2. transfers the weights PT -> TF -> PT with the model-to-model helpers and
       compares the first output of both models,
    3. repeats the transfer through on-disk checkpoints and compares again.

    The first outputs must not differ by more than ``4e-2`` (NaN positions are
    ignored on both sides).
    """
    import numpy as np
    import tensorflow as tf

    import transformers

    def to_tf_inputs(pt_inputs):
        """Build the TensorFlow input dict that mirrors ``pt_inputs``."""
        tf_inputs = {}
        for key, value in pt_inputs.items():
            if isinstance(value, bool):
                # Plain Python flags are forwarded as-is.
                tf_inputs[key] = value
            elif key == "input_values":
                # "input_values" is the only key converted as float32 here.
                tf_inputs[key] = tf.convert_to_tensor(value.numpy(), dtype=tf.float32)
            else:
                tf_inputs[key] = tf.convert_to_tensor(value.numpy(), dtype=tf.int32)
        return tf_inputs

    def max_abs_diff(tf_output, pt_output):
        """Return the largest absolute gap, treating NaNs on either side as zeros on both."""
        tf_array = tf_output.numpy()
        pt_array = pt_output.numpy()
        nan_mask = np.isnan(tf_array) | np.isnan(pt_array)
        tf_array = np.where(nan_mask, 0.0, tf_array)
        pt_array = np.where(nan_mask, 0.0, pt_array)
        return np.amax(np.abs(tf_array - pt_array))

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

        if not hasattr(transformers, tf_model_class_name):
            # transformers does not have a TF version yet.
            # NOTE(review): this ends the whole test, so later classes are not checked either.
            return

        tf_model_class = getattr(transformers, tf_model_class_name)

        config.output_hidden_states = True

        tf_model = tf_model_class(config)
        pt_model = model_class(config)

        # make sure only inputs that actually exist in the TF call signature are forwarded
        tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

        # remove all head masks
        tf_input_keys.discard("head_mask")
        tf_input_keys.discard("cross_attn_head_mask")
        tf_input_keys.discard("decoder_head_mask")

        pt_inputs = self._prepare_for_class(inputs_dict, model_class)
        pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}

        pt_model.eval()
        tf_inputs_dict = to_tf_inputs(pt_inputs)

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        # Check predictions on first output (logits/hidden-states) are close enough
        # given low-level computational differences
        with torch.no_grad():
            pto = pt_model(**pt_inputs)
        tfo = tf_model(tf_inputs_dict, training=False)

        self.assertLessEqual(max_abs_diff(tfo[0], pto[0]), 4e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Same closeness check after the checkpoint round trip
        pt_model.eval()
        tf_inputs_dict = to_tf_inputs(pt_inputs)

        with torch.no_grad():
            pto = pt_model(**pt_inputs)
        tfo = tf_model(tf_inputs_dict, training=False)

        self.assertLessEqual(max_abs_diff(tfo[0], pto[0]), 4e-2)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
    """Assert that the largest element-wise absolute difference of ``a`` and ``b`` is at most ``tol``.

    Args:
        a: first array.
        b: second array; ``a - b`` must be a valid element-wise operation.
        tol: inclusive upper bound on the maximum absolute difference.

    Raises:
        AssertionError: through ``self.assertLessEqual`` when the difference exceeds ``tol``.
    """
    diff = np.abs((a - b)).max()
    # The check passes when diff == tol, so a failure always means diff is strictly greater.
    self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol}).")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
    """Convert PyTorch weights to Flax and check that both models produce matching outputs.

    For each class in ``self.all_model_classes`` the PyTorch state dict is converted
    in memory into the ``Flax``-prefixed counterpart and every output is compared
    with a tolerance of ``4e-2``. The comparison is then repeated with a Flax model
    loaded from a saved PyTorch checkpoint (``from_pt=True``).
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            # PyTorch side, in eval mode
            torch_model = model_class(config).eval()
            # Flax models don't use the `use_cache` option and cache is not returned as a default.
            # So we disable `use_cache` here for PyTorch model.
            torch_model.config.use_cache = False

            flax_class_name = "Flax" + model_class.__name__
            if not hasattr(transformers, flax_class_name):
                return
            flax_class = getattr(transformers, flax_class_name)

            # Flax side
            flax_model = flax_class(config, dtype=jnp.float32)
            # only arguments present in the Flax call signature may be forwarded
            accepted_keys = inspect.signature(flax_model.__call__).parameters.keys()

            # prepare the inputs and drop the arguments that Flax does not accept
            torch_inputs = {
                name: value
                for name, value in self._prepare_for_class(inputs_dict, model_class).items()
                if name in accepted_keys
            }

            # in-memory weight conversion PyTorch -> Flax
            flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

            with torch.no_grad():
                torch_outputs = torch_model(**torch_inputs).to_tuple()

            # tensor inputs become arrays for Flax; non-tensor inputs are dropped
            flax_inputs = {}
            for name, value in torch_inputs.items():
                if torch.is_tensor(value):
                    flax_inputs[name] = np.array(value)

            flax_outputs = flax_model(**flax_inputs).to_tuple()
            self.assertEqual(
                len(flax_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
            )
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

            # same check through a saved PyTorch checkpoint
            with tempfile.TemporaryDirectory() as tmpdirname:
                torch_model.save_pretrained(tmpdirname)
                flax_model_reloaded = flax_class.from_pretrained(tmpdirname, from_pt=True)

            reloaded_outputs = flax_model_reloaded(**flax_inputs).to_tuple()
            self.assertEqual(
                len(reloaded_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
            )
            for reloaded_out, torch_out in zip(reloaded_outputs, torch_outputs):
                self.assert_almost_equals(reloaded_out, torch_out.numpy(), 4e-2)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
    """Load Flax weights into PyTorch and check that both models produce matching outputs.

    For each class in ``self.all_model_classes`` the parameters of the
    ``Flax``-prefixed counterpart are loaded into the PyTorch model in memory and
    every output is compared with a tolerance of ``4e-2``. The comparison is then
    repeated with a PyTorch model loaded from a saved Flax checkpoint
    (``from_flax=True``).
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            # PyTorch side, in eval mode
            torch_model = model_class(config).eval()
            # The Flax outputs compared below carry no cache,
            # so we disable `use_cache` here for PyTorch model.
            torch_model.config.use_cache = False

            flax_class_name = "Flax" + model_class.__name__
            if not hasattr(transformers, flax_class_name):
                # no flax model exists for this class
                return
            flax_class = getattr(transformers, flax_class_name)

            # Flax side
            flax_model = flax_class(config, dtype=jnp.float32)
            # only arguments present in the Flax call signature may be forwarded
            accepted_keys = inspect.signature(flax_model.__call__).parameters.keys()

            # in-memory weight transfer Flax -> PyTorch
            torch_model = load_flax_weights_in_pytorch_model(torch_model, flax_model.params)
            # make sure weights are tied in PyTorch
            torch_model.tie_weights()

            # prepare the inputs and drop the arguments that Flax does not accept
            torch_inputs = {
                name: value
                for name, value in self._prepare_for_class(inputs_dict, model_class).items()
                if name in accepted_keys
            }

            with torch.no_grad():
                torch_outputs = torch_model(**torch_inputs).to_tuple()

            # tensor inputs become arrays for Flax; non-tensor inputs are dropped
            flax_inputs = {}
            for name, value in torch_inputs.items():
                if torch.is_tensor(value):
                    flax_inputs[name] = np.array(value)

            flax_outputs = flax_model(**flax_inputs).to_tuple()
            self.assertEqual(
                len(flax_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
            )
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

            # same check through a saved Flax checkpoint
            with tempfile.TemporaryDirectory() as tmpdirname:
                flax_model.save_pretrained(tmpdirname)
                torch_model_reloaded = model_class.from_pretrained(tmpdirname, from_flax=True)

            with torch.no_grad():
                reloaded_outputs = torch_model_reloaded(**torch_inputs).to_tuple()

            self.assertEqual(
                len(flax_outputs), len(reloaded_outputs), "Output lengths differ between Flax and PyTorch"
            )
            for flax_out, reloaded_out in zip(flax_outputs, reloaded_outputs):
                self.assert_almost_equals(flax_out, reloaded_out.numpy(), 4e-2)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user