Fix device mismatch error in Whisper model during feature extraction (#35866)
* Fix device mismatch error in whisper feature extraction * Set default device * Address code review feedback --------- Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
This commit is contained in:
@@ -298,8 +298,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(3)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
with torch.device("cuda"):
|
||||
input_speech = self._load_datasamples(3)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEqual(input_features.shape, (3, 80, 3000))
|
||||
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user