[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -22,11 +22,8 @@ import unittest
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
import transformers
|
||||
from transformers import CLIPSegConfig, CLIPSegProcessor, CLIPSegTextConfig, CLIPSegVisionConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flax_available,
|
||||
is_pt_flax_cross_test,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
@@ -57,15 +54,6 @@ if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
|
||||
class CLIPSegVisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -635,123 +623,6 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
text_config = CLIPSegTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
# overwrite from common since FlaxCLIPSegModel returns nested output
|
||||
# which is not supported in the common test
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
|
||||
# overwrite from common since FlaxCLIPSegModel returns nested output
|
||||
# which is not supported in the common test
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# load corresponding PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
self.skipTest(reason="Training test is skipped as the model was not trained")
|
||||
|
||||
Reference in New Issue
Block a user