[Sequence Feature Extraction] Add truncation (#12804)

* fix_torch_device_generate_test

* remove @

* add truncate

* finish

* correct test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* clean tests

* correct normalization for truncation

* remove casting

* up

* save intermed

* finish

* finish

* correct

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-07-23 17:53:30 +02:00
committed by GitHub
parent 98364ea74f
commit f6e254474c
8 changed files with 370 additions and 38 deletions

View File

@@ -91,6 +91,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@@ -147,3 +148,26 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
def test_cepstral_mean_and_variance_normalization_trunc(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(
speech_inputs,
padding="max_length",
max_length=4,
truncation=True,
return_tensors="np",
return_attention_mask=True,
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1])
_check_zero_mean_unit_variance(input_features[2])

View File

@@ -83,6 +83,7 @@ class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
if equal_length:
speech_inputs = floats_list((self.batch_size, self.max_seq_length))
else:
# make sure that inputs increase in size
speech_inputs = [
_flatten(floats_list((x, self.feature_size)))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@@ -122,7 +123,7 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
def test_zero_mean_unit_variance_normalization(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract(speech_inputs, padding="longest")
processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
@@ -133,6 +134,22 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
_check_zero_mean_unit_variance(input_values[1, :1000])
_check_zero_mean_unit_variance(input_values[2])
def test_zero_mean_unit_variance_normalization_trunc(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract(
speech_inputs, truncation=True, max_length=1000, padding="max_length", return_tensors="np"
)
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0, :800])
_check_zero_mean_unit_variance(input_values[1])
_check_zero_mean_unit_variance(input_values[2])
@slow
@require_torch
def test_pretrained_checkpoints_are_set_correctly(self):

View File

@@ -715,7 +715,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
@@ -737,7 +737,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)

View File

@@ -126,12 +126,17 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
feature_size = self.feat_extract_tester.feature_size
# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_1 = feat_extract.pad(processed_features, padding=False)
input_1 = input_1[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")
input_2 = input_2[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))
input_3 = input_3[input_name]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")
input_4 = input_4[input_name]
# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
@@ -139,7 +144,8 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]
)
input_5 = input_5[input_name]
self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
@@ -154,18 +160,25 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)
input_6 = input_6[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)
input_7 = input_7[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
)
input_8 = input_8[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]
)
input_9 = input_9[input_name]
self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))
@@ -205,12 +218,149 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
< 1e-3
)
def _check_truncation(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True
def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False
for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
# truncate to smallest
input_1 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[0]), truncation=True
)
input_1 = input_1[input_name]
input_2 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[0]))
input_2 = input_2[input_name]
self.assertTrue(_inputs_have_equal_length(input_1))
self.assertFalse(_inputs_have_equal_length(input_2))
# truncate to smallest with np
input_3 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[0]),
return_tensors="np",
truncation=True,
)
input_3 = input_3[input_name]
input_4 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[0]), return_tensors="np"
)
input_4 = input_4[input_name]
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(input_3.shape[1] == len(speech_inputs[0]))
# since truncation forces padding to be smaller than longest input
# function can't return `np.ndarray`, but has to return list
self.assertFalse(_inputs_have_equal_length(input_4))
# truncate to middle
input_5 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[1]),
truncation=True,
return_tensors="np",
)
input_5 = input_5[input_name]
input_6 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[1]), truncation=True
)
input_6 = input_6[input_name]
input_7 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[1]), return_tensors="np"
)
input_7 = input_7[input_name]
self.assertTrue(input_5.shape[1] == len(speech_inputs[1]))
self.assertTrue(_inputs_have_equal_length(input_5))
self.assertTrue(_inputs_have_equal_length(input_6))
self.assertTrue(_inputs_are_equal(input_5, input_6))
# since truncation forces padding to be smaller than longest input
# function can't return `np.ndarray`, but has to return list
self.assertFalse(_inputs_have_equal_length(input_7))
self.assertTrue(len(input_7[-1]) == len(speech_inputs[-1]))
# padding has to be max_length when setting `truncation=True`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, truncation=True)[input_name]
# padding has to be max_length when setting `truncation=True`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="longest", truncation=True)[input_name]
# padding has to be max_length when setting `truncation=True`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="longest", truncation=True)[input_name]
# max_length parameter has to be provided when setting `truncation=True` and padding="max_length"
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length", truncation=True)[input_name]
# test truncation for `pad_to_multiple_of` for List[int] + numpy
pad_to_multiple_of = 12
input_8 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[0]),
pad_to_multiple_of=pad_to_multiple_of,
truncation=True,
)
input_8 = input_8[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[0]),
pad_to_multiple_of=pad_to_multiple_of,
)
input_9 = input_9[input_name]
# retrieve expected_length as multiple of pad_to_multiple_of
expected_length = len(speech_inputs[0])
if expected_length % pad_to_multiple_of != 0:
expected_length = ((len(speech_inputs[0]) // pad_to_multiple_of) + 1) * pad_to_multiple_of
self.assertTrue(len(input_8[0]) == expected_length)
self.assertTrue(_inputs_have_equal_length(input_8))
self.assertFalse(_inputs_have_equal_length(input_9))
def test_padding_from_list(self):
self._check_padding(numpify=False)
def test_padding_from_array(self):
self._check_padding(numpify=True)
def test_truncation_from_list(self):
self._check_truncation(numpify=False)
def test_truncation_from_array(self):
self._check_truncation(numpify=True)
@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
@@ -251,3 +401,25 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)
def test_attention_mask_with_truncation(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]
processed = BatchFeature({input_name: speech_inputs})
max_length = min(input_lenghts)
processed_pad = feat_extract.pad(
processed, padding="max_length", max_length=max_length, truncation=True, return_tensors="np"
)
self.assertIn("attention_mask", processed_pad)
self.assertListEqual(
list(processed_pad.attention_mask.shape), list((processed_pad[input_name].shape[0], max_length))
)
self.assertListEqual(
processed_pad.attention_mask[:, :max_length].sum(-1).tolist(), [max_length for x in speech_inputs]
)