From 26700a951672f795e553cf0a357c1f221c6053c1 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Mon, 6 Sep 2021 15:55:13 +0300 Subject: [PATCH] Fix scheduled tests for `SpeechEncoderDecoderModel` (#13422) * Add inputs to pretrained tests * Make style --- tests/test_modeling_speech_encoder_decoder.py | 65 ++++++++++++------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/tests/test_modeling_speech_encoder_decoder.py b/tests/test_modeling_speech_encoder_decoder.py index e5562d4c68..6b0e5cf12f 100644 --- a/tests/test_modeling_speech_encoder_decoder.py +++ b/tests/test_modeling_speech_encoder_decoder.py @@ -21,7 +21,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_modeling_bert import BertModelTester -from .test_modeling_common import ids_tensor +from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from .test_modeling_speech_to_text import Speech2TextModelTester from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester from .test_modeling_wav2vec2 import Wav2Vec2ModelTester @@ -50,7 +50,7 @@ class EncoderDecoderMixin: def prepare_config_and_inputs(self): pass - def get_pretrained_model(self): + def get_pretrained_model_and_inputs(self): pass def check_encoder_decoder_model_from_pretrained_configs( @@ -350,17 +350,11 @@ class EncoderDecoderMixin: @slow def test_real_model_save_load_from_pretrained(self): - model_2 = self.get_pretrained_model() + model_2, inputs = self.get_pretrained_model_and_inputs() model_2.to(torch_device) - input_name, inputs = self.get_inputs() - decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size) - attention_mask = ids_tensor([13, 5], vocab_size=2) + with torch.no_grad(): - outputs = model_2( - **{input_name: inputs}, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - ) + outputs = model_2(**inputs) out_2 = outputs[0].cpu().numpy() out_2[np.isnan(out_2)] = 0 @@ -369,11 +363,7 @@ class EncoderDecoderMixin: model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname) model_1.to(torch_device) - after_outputs = model_1( - **{input_name: inputs}, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - ) + after_outputs = model_1(**inputs) out_1 = after_outputs[0].cpu().numpy() out_1[np.isnan(out_1)] = 0 max_diff = np.amax(np.abs(out_1 - out_2)) @@ -382,10 +372,23 @@ class EncoderDecoderMixin: @require_torch class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase): - def get_pretrained_model(self): - return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + def get_pretrained_model_and_inputs(self): + model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained( "facebook/wav2vec2-base-960h", "bert-base-cased" ) + batch_size = 13 + input_values = floats_tensor([batch_size, 512], model.encoder.config.vocab_size) + attention_mask = random_attention_mask([batch_size, 512]) + decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size) + decoder_attention_mask = random_attention_mask([batch_size, 4]) + inputs = { + "input_values": input_values, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + return model, inputs def get_encoder_decoder_model(self, config, decoder_config): encoder_model = Wav2Vec2Model(config).eval() @@ -433,10 +436,23 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase): @require_torch class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase): - def get_pretrained_model(self): - return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + def get_pretrained_model_and_inputs(self): + model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained( "facebook/s2t-small-librispeech-asr", "bert-base-cased" ) + batch_size = 13 + input_features = floats_tensor([batch_size, 7, 80], model.encoder.config.vocab_size) + attention_mask = random_attention_mask([batch_size, 7]) + decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size) + decoder_attention_mask = random_attention_mask([batch_size, 4]) + inputs = { + "input_features": input_features, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + return model, inputs def get_encoder_decoder_model(self, config, decoder_config): encoder_model = Speech2TextEncoder(config).eval() @@ -489,6 +505,10 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase): def test_save_and_load_from_pretrained(self): pass + # all published pretrained models are Speech2TextModel != Speech2TextEncoder + def test_real_model_save_load_from_pretrained(self): + pass + @require_torch class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase): @@ -524,5 +544,6 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase): "decoder_attention_mask": decoder_attention_mask, } - def get_pretrained_model(self): - return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large") + # there are no published pretrained Speech2Text2ForCausalLM for now + def test_real_model_save_load_from_pretrained(self): + pass