Add SpeechEncoderDecoder & Speech2Text2 (#13186)
* fix_torch_device_generate_test * remove @ * up * correct some bugs * correct model * finish speech2text extension * up * up * up * up * Update utils/custom_init_isort.py * up * up * update with tokenizer * correct old tok * correct old tok * fix bug * up * up * add more tests * up * fix docs * up * fix some more tests * add better config * correct some more things " * fix tests * improve docs * Apply suggestions from code review * Apply suggestions from code review * final fixes * finalize * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * apply suggestions Lysandre and Sylvain * apply nicos suggestions * upload everything * finish Co-authored-by: Patrick von Platen <patrick@huggingface.co> Co-authored-by: your_github_username <your_github_email> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
9396b40433
commit
0b8c84e110
527
tests/test_modeling_speech_encoder_decoder.py
Normal file
527
tests/test_modeling_speech_encoder_decoder.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_common import ids_tensor
|
||||
from .test_modeling_speech_to_text import Speech2TextModelTester
|
||||
from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
|
||||
from .test_modeling_wav2vec2 import Wav2Vec2ModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
BertLMHeadModel,
|
||||
Speech2Text2ForCausalLM,
|
||||
SpeechEncoderDecoderConfig,
|
||||
SpeechEncoderDecoderModel,
|
||||
Wav2Vec2Model,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextEncoder
|
||||
|
||||
|
||||
@require_torch
|
||||
class EncoderDecoderMixin:
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
pass
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pass
|
||||
|
||||
def get_pretrained_model(self):
|
||||
pass
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained_configs(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
return_dict,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||
enc_dec_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_save_and_load(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
enc_dec_model.save_pretrained(tmpdirname)
|
||||
enc_dec_model = SpeechEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def check_save_and_load_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
|
||||
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
|
||||
)
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels=None,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_values=input_values,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
|
||||
inputs = input_values if input_features is None else input_features
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
seq_len = enc_dec_model.encoder._get_feat_extract_output_lengths(inputs.shape[1])
|
||||
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
num_decoder_layers = (
|
||||
decoder_config.num_decoder_layers
|
||||
if hasattr(decoder_config, "num_decoder_layers")
|
||||
else decoder_config.num_hidden_layers
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertEqual(
|
||||
decoder_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
|
||||
self.assertEqual(
|
||||
cross_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(
|
||||
self, config, decoder_config, input_values=None, input_features=None, **kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
inputs = input_values if input_features is None else input_features
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def test_encoder_decoder_model(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_return_dict(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
||||
|
||||
def test_save_and_load_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_and_load(**input_ids_dict)
|
||||
|
||||
def test_save_and_load_from_encoder_decoder_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_and_load_encoder_decoder_model(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
model_2.to(torch_device)
|
||||
input_name, inputs = self.get_inputs()
|
||||
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
||||
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||
with torch.no_grad():
|
||||
outputs = model_2(
|
||||
**{input_name: inputs},
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model_2.save_pretrained(tmp_dirname)
|
||||
model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||
model_1.to(torch_device)
|
||||
|
||||
after_outputs = model_1(
|
||||
**{input_name: inputs},
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/wav2vec2-base-960h", "bert-base-cased"
|
||||
)
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = Wav2Vec2Model(config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
bert_model_tester = BertModelTester(self)
|
||||
wav2vec2_model_tester = Wav2Vec2ModelTester(self)
|
||||
encoder_config_and_inputs = wav2vec2_model_tester.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||
(
|
||||
config,
|
||||
input_values,
|
||||
input_mask,
|
||||
) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_token_type_ids,
|
||||
decoder_input_mask,
|
||||
decoder_sequence_labels,
|
||||
decoder_token_labels,
|
||||
decoder_choice_labels,
|
||||
encoder_attention_mask,
|
||||
_,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_values": input_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
"decoder_attention_mask": decoder_input_mask,
|
||||
"decoder_sequence_labels": decoder_sequence_labels,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/s2t-small-librispeech-asr", "bert-base-cased"
|
||||
)
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = Speech2TextEncoder(config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
bert_model_tester = BertModelTester(self)
|
||||
speech2text_model_tester = Speech2TextModelTester(self)
|
||||
encoder_config_and_inputs = speech2text_model_tester.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
config, inputs = encoder_config_and_inputs
|
||||
input_features = inputs["input_features"]
|
||||
input_mask = inputs["attention_mask"]
|
||||
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_token_type_ids,
|
||||
decoder_input_mask,
|
||||
decoder_sequence_labels,
|
||||
decoder_token_labels,
|
||||
decoder_choice_labels,
|
||||
encoder_attention_mask,
|
||||
_,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_features": input_features,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
"decoder_attention_mask": decoder_input_mask,
|
||||
"decoder_sequence_labels": decoder_sequence_labels,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
# can't save full model for now because Speech2TextModel != Speech2TextEncoder
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
pass
|
||||
|
||||
# can't save full model for now because Speech2TextModel != Speech2TextEncoder
|
||||
def test_save_and_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = Wav2Vec2Model(config)
|
||||
decoder_model = Speech2Text2ForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = Wav2Vec2ModelTester(self, batch_size=13)
|
||||
model_tester_decoder = Speech2Text2StandaloneDecoderModelTester(
|
||||
self, batch_size=13, d_model=32, max_position_embeddings=512
|
||||
)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_values,
|
||||
input_mask,
|
||||
) = encoder_config_and_inputs
|
||||
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
# disable cache for now
|
||||
decoder_config.use_cache = False
|
||||
return {
|
||||
"config": config,
|
||||
"input_values": input_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
|
||||
@@ -241,11 +241,15 @@ class Speech2TextModelTester:
|
||||
decoder.save_pretrained(tmpdirname)
|
||||
decoder = Speech2TextDecoder.from_pretrained(tmpdirname).to(torch_device)
|
||||
|
||||
encoder_attention_mask = encoder._get_feature_vector_attention_mask(
|
||||
encoder_last_hidden_state.shape[1], inputs_dict["attention_mask"]
|
||||
)
|
||||
|
||||
last_hidden_state_2 = decoder(
|
||||
input_ids=inputs_dict["decoder_input_ids"],
|
||||
attention_mask=inputs_dict["decoder_attention_mask"],
|
||||
encoder_hidden_states=encoder_last_hidden_state,
|
||||
encoder_attention_mask=inputs_dict["attention_mask"],
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)[0]
|
||||
|
||||
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
|
||||
@@ -288,6 +292,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
||||
|
||||
# not implemented currently
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@@ -352,7 +357,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
else:
|
||||
seq_length = self.model_tester.seq_length
|
||||
|
||||
subsampled_seq_length = model._get_subsampled_output_lengths(seq_length)
|
||||
subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
@@ -402,8 +407,8 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
subsampled_encoder_seq_length = model._get_subsampled_output_lengths(encoder_seq_length)
|
||||
subsampled_encoder_key_length = model._get_subsampled_output_lengths(encoder_key_length)
|
||||
subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
|
||||
subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
211
tests/test_modeling_speech_to_text_2.py
Normal file
211
tests/test_modeling_speech_to_text_2.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Speech2Text model. """
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import Speech2Text2Config
|
||||
from transformers.testing_utils import is_torch_available, require_torch, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.speech_to_text_2.modeling_speech_to_text_2 import (
|
||||
Speech2Text2Decoder,
|
||||
Speech2Text2ForCausalLM,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class Speech2Text2StandaloneDecoderModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
vocab_size=99,
|
||||
batch_size=13,
|
||||
d_model=16,
|
||||
decoder_seq_length=7,
|
||||
is_training=True,
|
||||
is_decoder=True,
|
||||
use_attention_mask=True,
|
||||
use_cache=False,
|
||||
use_labels=True,
|
||||
decoder_start_token_id=2,
|
||||
decoder_ffn_dim=32,
|
||||
decoder_layers=4,
|
||||
decoder_attention_heads=4,
|
||||
max_position_embeddings=30,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.decoder_seq_length = decoder_seq_length
|
||||
# For common tests
|
||||
self.seq_length = self.decoder_seq_length
|
||||
self.is_training = is_training
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.use_labels = use_labels
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.hidden_size = d_model
|
||||
self.num_hidden_layers = decoder_layers
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.num_attention_heads = decoder_attention_heads
|
||||
self.eos_token_id = eos_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.use_cache = use_cache
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.scope = None
|
||||
self.decoder_key_length = decoder_seq_length
|
||||
self.base_model_out_len = 2
|
||||
self.decoder_attention_idx = 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
|
||||
attention_mask = None
|
||||
if self.use_attention_mask:
|
||||
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||
|
||||
lm_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
|
||||
config = Speech2Text2Config(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.d_model,
|
||||
decoder_layers=self.decoder_layers,
|
||||
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||
decoder_attention_heads=self.decoder_attention_heads,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
use_cache=self.use_cache,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
lm_labels,
|
||||
)
|
||||
|
||||
def create_and_check_decoder_model_past(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
config.use_cache = True
|
||||
model = Speech2Text2Decoder(config=config).to(torch_device).eval()
|
||||
input_ids = input_ids[:2]
|
||||
|
||||
input_ids[input_ids == 0] += 1
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids)
|
||||
outputs_no_past = model(input_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
past_key_values = outputs["past_key_values"]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((2, 1), config.vocab_size - 1) + 1
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
print(next_input_ids)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
lm_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
def setUp(
|
||||
self,
|
||||
):
|
||||
self.model_tester = Speech2Text2StandaloneDecoderModelTester(self, is_training=False)
|
||||
self.config_tester = ConfigTester(self, config_class=Speech2Text2Config)
|
||||
|
||||
# not implemented currently
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# speech2text2 has no base model
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
# speech2text2 has no base model
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_decoder_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||
|
||||
# decoder cannot keep gradients
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
@@ -76,6 +76,8 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
model = "anton-l/wav2vec2-random-tiny-classifier"
|
||||
|
||||
# hack: dummy tokenizer is required to prevent pipeline from failing
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
audio_classifier = pipeline("audio-classification", model=model, tokenizer=tokenizer)
|
||||
|
||||
@@ -98,6 +100,8 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
import datasets
|
||||
|
||||
model = "superb/wav2vec2-base-superb-ks"
|
||||
|
||||
# hack: dummy tokenizer is required to prevent pipeline from failing
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
audio_classifier = pipeline("audio-classification", model=model, tokenizer=tokenizer)
|
||||
dataset = datasets.load_dataset("anton-l/superb_dummy", "ks", split="test")
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import AutoFeatureExtractor, AutoTokenizer, Speech2TextForConditionalGeneration, Wav2Vec2ForCTC
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_datasets, require_torch, require_torchaudio, slow
|
||||
@@ -44,6 +46,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "C'est ce que j'ai fait à ce moment-là."})
|
||||
|
||||
@require_torch
|
||||
def test_torch_small_no_tokenizer_files(self):
|
||||
# test that model without tokenizer file cannot be loaded
|
||||
with pytest.raises(ValueError):
|
||||
pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="hf-internal-testing/tiny-random-wav2vec2",
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
@require_datasets
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -67,6 +79,24 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
@require_datasets
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_speech_encoder_decoder(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/s2t-wav2vec2-large-en-de",
|
||||
feature_extractor="facebook/s2t-wav2vec2-large-en-de",
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_datasets
|
||||
|
||||
155
tests/test_tokenization_speech_to_text_2.py
Normal file
155
tests/test_tokenization_speech_to_text_2.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.models.speech_to_text_2 import Speech2Text2Tokenizer
|
||||
from transformers.models.speech_to_text_2.tokenization_speech_to_text_2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import is_pt_tf_cross_test
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = Speech2Text2Tokenizer
|
||||
test_rust_tokenizer = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
vocab = "<s> <pad> </s> <unk> here@@ a couple of@@ words for the vocab".split(" ")
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
|
||||
self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
|
||||
def test_get_vocab(self):
|
||||
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
|
||||
|
||||
self.assertEqual(vocab_keys[0], "<s>")
|
||||
self.assertEqual(vocab_keys[1], "<pad>")
|
||||
self.assertEqual(vocab_keys[-1], "vocab")
|
||||
self.assertEqual(len(vocab_keys), 12)
|
||||
|
||||
def test_vocab_size(self):
|
||||
self.assertEqual(self.get_tokenizer().vocab_size, 12)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
tokenizer = Speech2Text2Tokenizer.from_pretrained(self.tmpdirname)
|
||||
|
||||
# make sure @@ is correctly concatenated
|
||||
token_ids = [4, 6, 8, 7, 10] # ["here@@", "couple", "words", "of@@", "the"]
|
||||
output_string = tokenizer.decode(token_ids)
|
||||
|
||||
self.assertTrue(output_string == "herecouple words ofthe")
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_add_special_tokens(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_add_tokens_tokenizer(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_added_tokens_do_lower_case(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_batch_encode_plus_batch_sequence_length(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_batch_encode_plus_overflowing_tokens(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_batch_encode_plus_padding(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_call(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_encode_plus_with_padding(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_internal_consistency(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_number_of_added_tokens(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_padding_to_max_length(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_padding_to_multiple_of(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_pickle_tokenizer(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_prepare_for_model(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_pretokenized_inputs(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_right_and_left_padding(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_save_and_load_tokenizer(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_special_tokens_mask(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_special_tokens_mask_input_pairs(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
def test_token_type_ids(self):
|
||||
pass
|
||||
|
||||
# currently tokenizer cannot do encoding, but just decoding
|
||||
@is_pt_tf_cross_test
|
||||
def test_batch_encode_plus_tensors(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user