Fix error in M4T feature extractor (#28340)
* fix M4T FE error when no attention mask
* modify logic
* add test
* go back to initial test situation + add other tests
This commit is contained in:
@@ -171,6 +171,42 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_call_without_attention_mask(self):
    """Check that no `attention_mask` entry is returned when the mask is disabled.

    Two ways of disabling the mask are covered:

    1. per call, by passing `return_attention_mask=False`;
    2. per instance, by constructing the feature extractor with
       `return_attention_mask=False` and leaving the call-time value unset.
    """
    feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
    feature_extractor = self.feature_extraction_class(**feature_extractor_args)

    # create three inputs of length 800, 1000, and 1200
    speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
    np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

    # Case 1: the mask is explicitly disabled in the call itself.
    output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
    self.assertNotIn("attention_mask", output)

    # Case 2: the mask is disabled by the constructor argument only.
    feature_extractor_args["return_attention_mask"] = False
    feature_extractor = self.feature_extraction_class(**feature_extractor_args)
    # The call must NOT pass `False` here, otherwise this case merely repeats
    # case 1 and the instance-level setting is never exercised.
    # NOTE(review): `None` is assumed to mean "not specified at call time, use
    # the instance setting" -- confirm against the extractor's `__call__`.
    output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=None)
    self.assertNotIn("attention_mask", output)
|
||||
|
||||
def test_attention_mask(self):
    """Check the shape of the attention mask returned alongside the features.

    The mask is expected to be 2-D, to have one row per input, and to have a
    last dimension equal to dimension 1 of `input_features`.
    """
    feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()

    feature_extractor = self.feature_extraction_class(**feature_extractor_args)
    # create three inputs of length 800, 1000, and 1200
    speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
    np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

    # No `return_attention_mask` argument is given to the call; the test reads
    # `attention_mask` from the output as produced by this configuration.
    output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np")
    input_features = output.input_features

    attention_mask = output.attention_mask
    # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
    self.assertEqual(attention_mask.ndim, 2)
    self.assertEqual(attention_mask.shape[0], len(np_speech_inputs))
    self.assertEqual(attention_mask.shape[-1], input_features.shape[1])
|
||||
|
||||
@require_torch
|
||||
def test_call_torch(self):
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user