[tests] remove flax-pt equivalence and cross tests (#36283)

This commit is contained in:
Joao Gante
2025-02-19 15:13:27 +00:00
committed by GitHub
parent fa8cdccd91
commit 99adc74462
39 changed files with 33 additions and 3103 deletions

View File

@@ -19,7 +19,7 @@ import unittest
import numpy as np
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from transformers.testing_utils import require_flax, slow
from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
@@ -43,14 +43,8 @@ if is_flax_available():
SpeechEncoderDecoderConfig,
)
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import SpeechEncoderDecoderModel
@@ -406,68 +400,6 @@ class FlaxEncoderDecoderMixin:
for grad, grad_frozen in zip(grads, grads_frozen):
self.assertTrue((grad == grad_frozen).all())
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
@@ -504,46 +436,6 @@ class FlaxEncoderDecoderMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
decoder_config.hidden_size = config.hidden_size
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` work as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `add_adapter` works as expected
config.add_adapter = True
self.assertTrue(config.add_adapter)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
@@ -625,71 +517,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2gpt2_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("jsnfly/wav2vec2-large-xlsr-53-german-gpt2")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
"jsnfly/wav2vec2-large-xlsr-53-german-gpt2", from_pt=True
)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@require_flax
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@@ -742,71 +569,6 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2bart_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
"patrickvonplaten/wav2vec2-2-bart-large", from_pt=True
)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@require_flax
class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@@ -858,66 +620,3 @@ class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2bert_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)