[SequenceFeatureExtractor] Rewrite padding logic from pure python to numpy (#13650)
* Test np padding * Pass feature extraction tests * Update type hints * Fix flaky integration tests * Try a more stable waveform * Add to_numpy jax support * int32 attention masks * Refactor normalization tests
This commit is contained in:
@@ -110,6 +110,10 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
|
||||
|
||||
def _check_zero_mean_unit_variance(self, input_vector):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -137,17 +141,9 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
# TODO(Patrick, Suraj, Anton) - It's surprising that "non-padded/non-numpified" padding
|
||||
# results in quite inaccurate variance computation after (see 5e-1 tolerance)
|
||||
# Issue is filed and PR is underway: https://github.com/huggingface/transformers/issues/13539
|
||||
# paddings = ["longest", "max_length", "do_not_pad"]
|
||||
# max_lengths = [None, 16, None]
|
||||
# var_tolerances = [1e-3, 1e-3, 5e-1]
|
||||
paddings = ["longest", "max_length"]
|
||||
max_lengths = [None, 16]
|
||||
var_tolerances = [1e-3, 1e-3]
|
||||
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 16, None]
|
||||
for max_length, padding in zip(max_lengths, paddings):
|
||||
inputs = feature_extractor(
|
||||
speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True
|
||||
)
|
||||
@@ -155,28 +151,17 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
|
||||
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
|
||||
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
|
||||
self._check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]])
|
||||
self._check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]])
|
||||
self._check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]])
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization_np(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
# TODO(Patrick, Suraj, Anton) - It's surprising that "non-padded/non-numpified" padding
|
||||
# results in quite inaccurate variance computation after (see 5e-1 tolerance)
|
||||
# Issue is filed and PR is underway: https://github.com/huggingface/transformers/issues/13539
|
||||
# paddings = ["longest", "max_length", "do_not_pad"]
|
||||
# max_lengths = [None, 16, None]
|
||||
# var_tolerances = [1e-3, 1e-3, 5e-1]
|
||||
paddings = ["longest", "max_length"]
|
||||
max_lengths = [None, 16]
|
||||
var_tolerances = [1e-3, 1e-3]
|
||||
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 16, None]
|
||||
for max_length, padding in zip(max_lengths, paddings):
|
||||
inputs = feature_extractor(
|
||||
speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
|
||||
)
|
||||
@@ -184,15 +169,11 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
|
||||
self._check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]])
|
||||
self.assertTrue(input_features[0][fbank_feat_lengths[0] :].sum() < 1e-6)
|
||||
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
|
||||
self._check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]])
|
||||
self.assertTrue(input_features[0][fbank_feat_lengths[1] :].sum() < 1e-6)
|
||||
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
|
||||
self._check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]])
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization_trunc_max_length(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -209,13 +190,9 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
_check_zero_mean_unit_variance(input_features[1])
|
||||
_check_zero_mean_unit_variance(input_features[2])
|
||||
self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
self._check_zero_mean_unit_variance(input_features[1])
|
||||
self._check_zero_mean_unit_variance(input_features[2])
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization_trunc_longest(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -232,13 +209,9 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
||||
_check_zero_mean_unit_variance(input_features[2])
|
||||
self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
self._check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
||||
self._check_zero_mean_unit_variance(input_features[2])
|
||||
|
||||
# make sure that if max_length < longest -> then pad to max_length
|
||||
self.assertEqual(input_features.shape, (3, 4, 24))
|
||||
@@ -256,9 +229,9 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
||||
_check_zero_mean_unit_variance(input_features[2])
|
||||
self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
self._check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
||||
self._check_zero_mean_unit_variance(input_features[2])
|
||||
|
||||
# make sure that if max_length < longest -> then pad to max_length
|
||||
self.assertEqual(input_features.shape, (3, 6, 24))
|
||||
|
||||
@@ -102,6 +102,10 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = Wav2Vec2FeatureExtractionTester(self)
|
||||
|
||||
def _check_zero_mean_unit_variance(self, input_vector):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -130,15 +134,11 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0][:800])
|
||||
self._check_zero_mean_unit_variance(input_values[0][:800])
|
||||
self.assertTrue(input_values[0][800:].sum() < 1e-6)
|
||||
_check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
self._check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
self.assertTrue(input_values[0][1000:].sum() < 1e-6)
|
||||
_check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
self._check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -152,13 +152,9 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0][:800])
|
||||
_check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
_check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
self._check_zero_mean_unit_variance(input_values[0][:800])
|
||||
self._check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
self._check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_max_length(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -168,13 +164,9 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
)
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
||||
_check_zero_mean_unit_variance(input_values[1])
|
||||
_check_zero_mean_unit_variance(input_values[2])
|
||||
self._check_zero_mean_unit_variance(input_values[0, :800])
|
||||
self._check_zero_mean_unit_variance(input_values[1])
|
||||
self._check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
@@ -184,13 +176,9 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
)
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
||||
_check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
_check_zero_mean_unit_variance(input_values[2])
|
||||
self._check_zero_mean_unit_variance(input_values[0, :800])
|
||||
self._check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
self._check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
# make sure that if max_length < longest -> then pad to max_length
|
||||
self.assertTrue(input_values.shape == (3, 1000))
|
||||
@@ -201,9 +189,9 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
)
|
||||
input_values = processed.input_values
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
||||
_check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
_check_zero_mean_unit_variance(input_values[2])
|
||||
self._check_zero_mean_unit_variance(input_values[0, :800])
|
||||
self._check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
self._check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
# make sure that if max_length > longest -> then pad to longest
|
||||
self.assertTrue(input_values.shape == (3, 1200))
|
||||
|
||||
@@ -724,7 +724,7 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
|
||||
return batch
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
ds = ds.select(range(num_samples)).map(map_to_array)
|
||||
ds = ds.sort("id").select(range(num_samples)).map(map_to_array)
|
||||
|
||||
return ds["speech"][:num_samples]
|
||||
|
||||
@@ -740,7 +740,9 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
]
|
||||
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_generation_librispeech_batched(self):
|
||||
@@ -759,10 +761,10 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
|
||||
generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
"sweat covered brion's body trickling into the titleing cloth that was the only garment he wore",
|
||||
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
|
||||
"his instant of panic was followed by a small sharp blow high on his chest",
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||
"nor is mister cultar's manner less interesting than his matter",
|
||||
"he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind",
|
||||
"he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it but little of rocky ithaca",
|
||||
]
|
||||
|
||||
self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
@@ -42,9 +42,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
tokenizer="facebook/s2t-small-mustc-en-fr-st",
|
||||
framework="pt",
|
||||
)
|
||||
waveform = np.zeros((34000,))
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "C'est ce que j'ai fait à ce moment-là."})
|
||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||
|
||||
@require_torch
|
||||
def test_torch_small_no_tokenizer_files(self):
|
||||
@@ -68,14 +68,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
tokenizer="facebook/wav2vec2-base-960h",
|
||||
framework="pt",
|
||||
)
|
||||
waveform = np.zeros((34000,))
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": ""})
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
@@ -92,8 +92,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})
|
||||
|
||||
@@ -110,16 +110,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
|
||||
asr = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
waveform = np.zeros((34000,))
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
output = asr(waveform)
|
||||
self.assertEqual(output, {"text": ""})
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = asr(filename)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
filename = ds[0]["file"]
|
||||
filename = ds[40]["file"]
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
output = asr(data)
|
||||
@@ -139,17 +139,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
|
||||
asr = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
waveform = np.zeros((34000,))
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
|
||||
output = asr(waveform)
|
||||
self.assertEqual(output, {"text": "E questo è il motivo per cui non ci siamo mai incontrati."})
|
||||
self.assertEqual(output, {"text": "(Applausi)"})
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = asr(filename)
|
||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||
|
||||
filename = ds[0]["file"]
|
||||
filename = ds[40]["file"]
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
output = asr(data)
|
||||
|
||||
@@ -372,7 +372,7 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
|
||||
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
|
||||
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]
|
||||
|
||||
self.assertTrue(abs(input_np.astype(np.float32).sum() - input_pt.numpy().sum()) < 1e-2)
|
||||
self.assertTrue(abs(input_np.astype(np.float32).sum() - input_pt.numpy().astype(np.float32).sum()) < 1e-2)
|
||||
|
||||
@require_tf
|
||||
def test_padding_accepts_tensors_tf(self):
|
||||
@@ -385,7 +385,7 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
|
||||
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
|
||||
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]
|
||||
|
||||
self.assertTrue(abs(input_np.astype(np.float32).sum() - input_tf.numpy().sum()) < 1e-2)
|
||||
self.assertTrue(abs(input_np.astype(np.float32).sum() - input_tf.numpy().astype(np.float32).sum()) < 1e-2)
|
||||
|
||||
def test_attention_mask(self):
|
||||
feat_dict = self.feat_extract_dict
|
||||
|
||||
Reference in New Issue
Block a user