Fix scheduled tests for SpeechEncoderDecoderModel (#13422)
* Add inputs to pretrained tests * Make style
This commit is contained in:
@@ -21,7 +21,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_common import ids_tensor
|
||||
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from .test_modeling_speech_to_text import Speech2TextModelTester
|
||||
from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
|
||||
from .test_modeling_wav2vec2 import Wav2Vec2ModelTester
|
||||
@@ -50,7 +50,7 @@ class EncoderDecoderMixin:
|
||||
def prepare_config_and_inputs(self):
|
||||
pass
|
||||
|
||||
def get_pretrained_model(self):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
pass
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained_configs(
|
||||
@@ -350,17 +350,11 @@ class EncoderDecoderMixin:
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
model_2.to(torch_device)
|
||||
input_name, inputs = self.get_inputs()
|
||||
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
||||
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model_2(
|
||||
**{input_name: inputs},
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
outputs = model_2(**inputs)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
@@ -369,11 +363,7 @@ class EncoderDecoderMixin:
|
||||
model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||
model_1.to(torch_device)
|
||||
|
||||
after_outputs = model_1(
|
||||
**{input_name: inputs},
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
after_outputs = model_1(**inputs)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
@@ -382,10 +372,23 @@ class EncoderDecoderMixin:
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/wav2vec2-base-960h", "bert-base-cased"
|
||||
)
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], model.encoder.config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"input_values": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = Wav2Vec2Model(config).eval()
|
||||
@@ -433,10 +436,23 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/s2t-small-librispeech-asr", "bert-base-cased"
|
||||
)
|
||||
batch_size = 13
|
||||
input_features = floats_tensor([batch_size, 7, 80], model.encoder.config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 7])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"input_features": input_features,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = Speech2TextEncoder(config).eval()
|
||||
@@ -489,6 +505,10 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def test_save_and_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
# all published pretrained models are Speech2TextModel != Speech2TextEncoder
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
||||
@@ -524,5 +544,6 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
|
||||
# there are no published pretrained Speech2Text2ForCausalLM for now
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user