[Sequence Feature Extraction] Add truncation (#12804)

* fix_torch_device_generate_test

* remove @

* add truncate

* finish

* correct test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* clean tests

* correct normalization for truncation

* remove casting

* up

* save intermed

* finish

* finish

* correct

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-07-23 17:53:30 +02:00
committed by GitHub
parent 98364ea74f
commit f6e254474c
8 changed files with 370 additions and 38 deletions

View File

@@ -91,6 +91,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@@ -147,3 +148,26 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
def test_cepstral_mean_and_variance_normalization_trunc(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(
speech_inputs,
padding="max_length",
max_length=4,
truncation=True,
return_tensors="np",
return_attention_mask=True,
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1])
_check_zero_mean_unit_variance(input_features[2])