device agnostic pipelines testing (#27129)
* device agnostic pipelines testing * pass torch_device
This commit is contained in:
@@ -39,9 +39,10 @@ from transformers.testing_utils import (
|
||||
require_pyctcdecode,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_torchaudio,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
@@ -166,13 +167,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
def test_whisper_fp16(self):
|
||||
if not torch.cuda.is_available():
|
||||
self.skipTest("Cuda is necessary for this test")
|
||||
speech_recognizer = pipeline(
|
||||
model="openai/whisper-base",
|
||||
device=0,
|
||||
device=torch_device,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
@@ -904,12 +903,12 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_wav2vec2_conformer_float16(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/wav2vec2-conformer-rope-large-960h-ft",
|
||||
device="cuda:0",
|
||||
device=torch_device,
|
||||
torch_dtype=torch.float16,
|
||||
framework="pt",
|
||||
)
|
||||
@@ -1304,14 +1303,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, {"text": "XB"})
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_slow_unfinished_sequence(self):
|
||||
from transformers import GenerationConfig
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="vasista22/whisper-hindi-large-v2",
|
||||
device="cuda:0",
|
||||
device=torch_device,
|
||||
)
|
||||
# Original model wasn't trained with timestamps and has incorrect generation config
|
||||
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
|
||||
|
||||
Reference in New Issue
Block a user