Update PT Flax equivalence tests in PT test file (#16280)

* update PT/Flax equivalence tests on PT side

* overwrite check_outputs in BigBirdModelTest

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-03-24 14:45:30 +01:00
committed by GitHub
parent 41bfc1e262
commit f571dc20ac
2 changed files with 136 additions and 48 deletions

View File

@@ -1660,8 +1660,9 @@ class ModelTesterMixin:
# transformers does not have TF version yet
return
if self.has_attentions:
config.output_attentions = True
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
if k in inputs_dict:
@@ -1728,12 +1729,65 @@ class ModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
if type(fx_outputs) in [tuple, list]:
self.assertEqual(type(fx_outputs), type(pt_outputs))
self.assertEqual(len(fx_outputs), len(pt_outputs))
if type(names) == tuple:
for fo, po, name in zip(fx_outputs, pt_outputs, names):
self.check_outputs(fo, po, model_class, names=name)
elif type(names) == str:
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[fx_nans] = 0
fx_outputs[fx_nans] = 0
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
else:
raise ValueError(
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
@@ -1741,15 +1795,9 @@ class ModelTesterMixin:
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
return
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
@@ -1759,29 +1807,41 @@ class ModelTesterMixin:
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@@ -1789,59 +1849,78 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()