Fix scheduled tests for SpeechEncoderDecoderModel (#13422)
* Add inputs to pretrained tests * Make style
This commit is contained in:
@@ -21,7 +21,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_modeling_bert import BertModelTester
|
from .test_modeling_bert import BertModelTester
|
||||||
from .test_modeling_common import ids_tensor
|
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from .test_modeling_speech_to_text import Speech2TextModelTester
|
from .test_modeling_speech_to_text import Speech2TextModelTester
|
||||||
from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
|
from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
|
||||||
from .test_modeling_wav2vec2 import Wav2Vec2ModelTester
|
from .test_modeling_wav2vec2 import Wav2Vec2ModelTester
|
||||||
@@ -50,7 +50,7 @@ class EncoderDecoderMixin:
|
|||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_pretrained_model(self):
|
def get_pretrained_model_and_inputs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def check_encoder_decoder_model_from_pretrained_configs(
|
def check_encoder_decoder_model_from_pretrained_configs(
|
||||||
@@ -350,17 +350,11 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_real_model_save_load_from_pretrained(self):
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
model_2 = self.get_pretrained_model()
|
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||||
model_2.to(torch_device)
|
model_2.to(torch_device)
|
||||||
input_name, inputs = self.get_inputs()
|
|
||||||
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
|
||||||
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model_2(
|
outputs = model_2(**inputs)
|
||||||
**{input_name: inputs},
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
)
|
|
||||||
out_2 = outputs[0].cpu().numpy()
|
out_2 = outputs[0].cpu().numpy()
|
||||||
out_2[np.isnan(out_2)] = 0
|
out_2[np.isnan(out_2)] = 0
|
||||||
|
|
||||||
@@ -369,11 +363,7 @@ class EncoderDecoderMixin:
|
|||||||
model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
|
model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
model_1.to(torch_device)
|
model_1.to(torch_device)
|
||||||
|
|
||||||
after_outputs = model_1(
|
after_outputs = model_1(**inputs)
|
||||||
**{input_name: inputs},
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
)
|
|
||||||
out_1 = after_outputs[0].cpu().numpy()
|
out_1 = after_outputs[0].cpu().numpy()
|
||||||
out_1[np.isnan(out_1)] = 0
|
out_1[np.isnan(out_1)] = 0
|
||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
@@ -382,10 +372,23 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||||
def get_pretrained_model(self):
|
def get_pretrained_model_and_inputs(self):
|
||||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
"facebook/wav2vec2-base-960h", "bert-base-cased"
|
"facebook/wav2vec2-base-960h", "bert-base-cased"
|
||||||
)
|
)
|
||||||
|
batch_size = 13
|
||||||
|
input_values = floats_tensor([batch_size, 512], model.encoder.config.vocab_size)
|
||||||
|
attention_mask = random_attention_mask([batch_size, 512])
|
||||||
|
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||||
|
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||||
|
inputs = {
|
||||||
|
"input_values": input_values,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, inputs
|
||||||
|
|
||||||
def get_encoder_decoder_model(self, config, decoder_config):
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
encoder_model = Wav2Vec2Model(config).eval()
|
encoder_model = Wav2Vec2Model(config).eval()
|
||||||
@@ -433,10 +436,23 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||||
def get_pretrained_model(self):
|
def get_pretrained_model_and_inputs(self):
|
||||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
"facebook/s2t-small-librispeech-asr", "bert-base-cased"
|
"facebook/s2t-small-librispeech-asr", "bert-base-cased"
|
||||||
)
|
)
|
||||||
|
batch_size = 13
|
||||||
|
input_features = floats_tensor([batch_size, 7, 80], model.encoder.config.vocab_size)
|
||||||
|
attention_mask = random_attention_mask([batch_size, 7])
|
||||||
|
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||||
|
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||||
|
inputs = {
|
||||||
|
"input_features": input_features,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, inputs
|
||||||
|
|
||||||
def get_encoder_decoder_model(self, config, decoder_config):
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
encoder_model = Speech2TextEncoder(config).eval()
|
encoder_model = Speech2TextEncoder(config).eval()
|
||||||
@@ -489,6 +505,10 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
def test_save_and_load_from_pretrained(self):
|
def test_save_and_load_from_pretrained(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# all published pretrained models are Speech2TextModel != Speech2TextEncoder
|
||||||
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
||||||
@@ -524,5 +544,6 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_pretrained_model(self):
|
# there are no published pretrained Speech2Text2ForCausalLM for now
|
||||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user