[Wav2Vec2FeatureExtractor] Fix extractor.pad() dtype backwards compatibility (#13693)
* Force dtype, add tests
* Local torch imports
* Remove unused logic (always ndarray)
This commit is contained in:
@@ -187,23 +187,6 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
|
|||||||
padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
|
padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
|
||||||
|
|
||||||
required_input = processed_features[self.model_input_names[0]]
|
required_input = processed_features[self.model_input_names[0]]
|
||||||
if required_input and not isinstance(required_input[0], np.ndarray):
|
|
||||||
# truncation
|
|
||||||
processed_features = self._truncate(
|
|
||||||
processed_features,
|
|
||||||
max_length=max_length,
|
|
||||||
pad_to_multiple_of=pad_to_multiple_of,
|
|
||||||
truncation=truncation,
|
|
||||||
)
|
|
||||||
# padding
|
|
||||||
processed_features = self._pad(
|
|
||||||
processed_features,
|
|
||||||
max_length=max_length,
|
|
||||||
padding_strategy=padding_strategy,
|
|
||||||
pad_to_multiple_of=pad_to_multiple_of,
|
|
||||||
return_attention_mask=return_attention_mask,
|
|
||||||
)
|
|
||||||
return BatchFeature(processed_features, tensor_type=return_tensors)
|
|
||||||
|
|
||||||
batch_size = len(required_input)
|
batch_size = len(required_input)
|
||||||
if not all(len(v) == batch_size for v in processed_features.values()):
|
if not all(len(v) == batch_size for v in processed_features.values()):
|
||||||
@@ -240,6 +223,8 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
|
|||||||
for key, value in outputs.items():
|
for key, value in outputs.items():
|
||||||
if key not in batch_outputs:
|
if key not in batch_outputs:
|
||||||
batch_outputs[key] = []
|
batch_outputs[key] = []
|
||||||
|
if value.dtype is np.dtype(np.float64):
|
||||||
|
value = value.astype(np.float32)
|
||||||
batch_outputs[key].append(value)
|
batch_outputs[key].append(value)
|
||||||
|
|
||||||
return BatchFeature(batch_outputs, tensor_type=return_tensors)
|
return BatchFeature(batch_outputs, tensor_type=return_tensors)
|
||||||
|
|||||||
@@ -235,3 +235,16 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
|||||||
|
|
||||||
# make sure that if max_length < longest -> then pad to max_length
|
# make sure that if max_length < longest -> then pad to max_length
|
||||||
self.assertEqual(input_features.shape, (3, 6, 24))
|
self.assertEqual(input_features.shape, (3, 6, 24))
|
||||||
|
|
||||||
|
def test_double_precision_pad(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
|
||||||
|
py_speech_inputs = np_speech_inputs.tolist()
|
||||||
|
|
||||||
|
for inputs in [py_speech_inputs, np_speech_inputs]:
|
||||||
|
np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
|
||||||
|
self.assertTrue(np_processed.input_features.dtype == np.float32)
|
||||||
|
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
|
||||||
|
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||||
|
|||||||
@@ -196,6 +196,20 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
|||||||
# make sure that if max_length > longest -> then pad to longest
|
# make sure that if max_length > longest -> then pad to longest
|
||||||
self.assertTrue(input_values.shape == (3, 1200))
|
self.assertTrue(input_values.shape == (3, 1200))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_double_precision_pad(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
np_speech_inputs = np.random.rand(100).astype(np.float64)
|
||||||
|
py_speech_inputs = np_speech_inputs.tolist()
|
||||||
|
|
||||||
|
for inputs in [py_speech_inputs, np_speech_inputs]:
|
||||||
|
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
|
||||||
|
self.assertTrue(np_processed.input_values.dtype == np.float32)
|
||||||
|
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
|
||||||
|
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pretrained_checkpoints_are_set_correctly(self):
|
def test_pretrained_checkpoints_are_set_correctly(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user