Fix device mismatch error in Whisper model during feature extraction (#35866)

* Fix device mismatch error in whisper feature extraction

* Set default device

* Address code review feedback

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
This commit is contained in:
Sumit Vij
2025-02-04 03:23:08 -08:00
committed by GitHub
parent 9afb904b15
commit bc9a6d8302
2 changed files with 7 additions and 11 deletions

View File

@@ -298,8 +298,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
)
# fmt: on
input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
with torch.device("cuda"):
input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEqual(input_features.shape, (3, 80, 3000))
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)