Fix error in M4T feature extractor (#28340)

* fix M4T FE error when no attention mask

* modify logic

* add test

* go back to initial test situation + add other tests
This commit is contained in:
Yoach Lacombe
2024-01-04 17:40:53 +01:00
committed by GitHub
parent 4a66c0d952
commit 35e9d2b223
2 changed files with 44 additions and 3 deletions

View File

@@ -171,6 +171,42 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_call_without_attention_mask(self):
feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test attention mask when passing no attention mask to forward call
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
self.assertTrue("attention_mask" not in output)
# Test attention mask when no attention mask by default
feature_extractor_args["return_attention_mask"] = False
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
self.assertTrue("attention_mask" not in output)
def test_attention_mask(self):
# test attention mask has the right output shape
feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test attention mask when passing it to forward call
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np")
input_features = output.input_features
attention_mask = output.attention_mask
self.assertTrue(attention_mask.ndim == 2)
self.assertTrue(attention_mask.shape[0] == 3)
self.assertTrue(attention_mask.shape[-1] == input_features.shape[1])
@require_torch
def test_call_torch(self):
import torch