[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -20,15 +20,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_flax,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_flax, require_torch, require_vision, slow
|
||||
from transformers.utils import is_flax_available, is_vision_available
|
||||
|
||||
from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
|
||||
@@ -45,17 +38,8 @@ if is_flax_available():
|
||||
VisionTextDualEncoderConfig,
|
||||
VisionTextDualEncoderProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import VisionTextDualEncoderModel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
@@ -154,68 +138,6 @@ class VisionTextDualEncoderMixin:
|
||||
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_model_from_pretrained_configs(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_model_from_pretrained_configs(**inputs_dict)
|
||||
@@ -232,17 +154,6 @@ class VisionTextDualEncoderMixin:
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_output_attention(**inputs_dict)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
|
||||
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
|
||||
@@ -20,8 +20,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bert.test_modeling_bert import BertModelTester
|
||||
@@ -44,12 +44,6 @@ if is_torch_available():
|
||||
ViTModel,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxVisionTextDualEncoderModel
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
@@ -172,69 +166,6 @@ class VisionTextDualEncoderMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs):
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
|
||||
pt_inputs = inputs_dict
|
||||
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def test_vision_text_dual_encoder_model(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_dual_encoder_model(**inputs_dict)
|
||||
@@ -255,17 +186,6 @@ class VisionTextDualEncoderMixin:
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_output_attention(**inputs_dict)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
|
||||
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
@@ -429,10 +349,6 @@ class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
"text_choice_labels": choice_labels,
|
||||
}
|
||||
|
||||
@unittest.skip(reason="DeiT is not available in Flax")
|
||||
def test_pt_flax_equivalence(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user