[Sequence Feature Extraction] Add truncation (#12804)
* fix_torch_device_generate_test * remove @ * add truncate * finish * correct test * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * clean tests * correct normalization for truncation * remove casting * up * save intermed * finish * finish * correct Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
98364ea74f
commit
f6e254474c
@@ -91,6 +91,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
|
||||
if equal_length:
|
||||
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
|
||||
else:
|
||||
# make sure that inputs increase in size
|
||||
speech_inputs = [
|
||||
floats_list((x, self.feature_size))
|
||||
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
|
||||
@@ -147,3 +148,26 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
||||
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization_trunc(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
inputs = feature_extractor(
|
||||
speech_inputs,
|
||||
padding="max_length",
|
||||
max_length=4,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
input_features = inputs.input_features
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
_check_zero_mean_unit_variance(input_features[1])
|
||||
_check_zero_mean_unit_variance(input_features[2])
|
||||
|
||||
Reference in New Issue
Block a user