make test_speculative_decoding_non_distil device-agnostic (#38010)
* make device-agnostic * use condition --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -2495,9 +2495,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
|
self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_speculative_decoding_non_distil(self):
|
def test_speculative_decoding_non_distil(self):
|
||||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
torch_dtype = torch.float16 if torch_device in ["cuda", "xpu"] else torch.float32
|
||||||
model_id = "openai/whisper-large-v2"
|
model_id = "openai/whisper-large-v2"
|
||||||
model = WhisperForConditionalGeneration.from_pretrained(
|
model = WhisperForConditionalGeneration.from_pretrained(
|
||||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
||||||
|
|||||||
Reference in New Issue
Block a user