[Wav2Vec2FeatureExtractor] Fix extractor.pad() dtype backwards compatibility (#13693)

* Force dtype, add tests

* Local torch imports

* Remove unused logic (always ndarray)
This commit is contained in:
Anton Lozhkov
2021-09-22 12:02:54 +03:00
committed by GitHub
parent 8e908c8c74
commit 75f6641eaf
3 changed files with 29 additions and 17 deletions

View File

@@ -235,3 +235,16 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
# make sure that if max_length < longest -> then pad to max_length
self.assertEqual(input_features.shape, (3, 6, 24))
def test_double_precision_pad(self):
import torch
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
py_speech_inputs = np_speech_inputs.tolist()
for inputs in [py_speech_inputs, np_speech_inputs]:
np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)

View File

@@ -196,6 +196,20 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
# make sure that if max_length > longest -> then pad to longest
self.assertTrue(input_values.shape == (3, 1200))
@require_torch
def test_double_precision_pad(self):
import torch
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_speech_inputs = np.random.rand(100).astype(np.float64)
py_speech_inputs = np_speech_inputs.tolist()
for inputs in [py_speech_inputs, np_speech_inputs]:
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_values.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
@slow
@require_torch
def test_pretrained_checkpoints_are_set_correctly(self):