device agnostic pipelines testing (#27129)

* device agnostic pipelines testing

* pass torch_device
This commit is contained in:
Hz, Ji
2023-10-31 22:46:31 +08:00
committed by GitHub
parent 08fadc8085
commit f53041a753
10 changed files with 64 additions and 58 deletions

View File

@@ -39,9 +39,10 @@ from transformers.testing_utils import (
require_pyctcdecode,
require_tf,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_torchaudio,
slow,
torch_device,
)
from .test_pipelines_common import ANY
@@ -166,13 +167,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
_ = speech_recognizer(waveform, return_timestamps="char")
@slow
@require_torch
@require_torch_accelerator
def test_whisper_fp16(self):
if not torch.cuda.is_available():
self.skipTest("Cuda is necessary for this test")
speech_recognizer = pipeline(
model="openai/whisper-base",
device=0,
device=torch_device,
torch_dtype=torch.float16,
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
@@ -904,12 +903,12 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@slow
@require_torch_gpu
@require_torch_accelerator
def test_wav2vec2_conformer_float16(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-conformer-rope-large-960h-ft",
device="cuda:0",
device=torch_device,
torch_dtype=torch.float16,
framework="pt",
)
@@ -1304,14 +1303,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, {"text": "XB"})
@slow
@require_torch_gpu
@require_torch_accelerator
def test_slow_unfinished_sequence(self):
from transformers import GenerationConfig
pipe = pipeline(
"automatic-speech-recognition",
model="vasista22/whisper-hindi-large-v2",
device="cuda:0",
device=torch_device,
)
# Original model wasn't trained with timestamps and has incorrect generation config
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")