From fa7f3cf336eb5d93cfaa7611723c299e7851fb02 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 5 Mar 2024 16:16:05 +0800 Subject: [PATCH] [tests] enable test_pipeline_accelerate_top_p on XPU (#29309) * use torch_device * Update tests/pipelines/test_pipelines_text_generation.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- tests/pipelines/test_pipelines_text_generation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 766f2a462a..ada04c7dbe 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -450,7 +450,9 @@ class TextGenerationPipelineTests(unittest.TestCase): def test_pipeline_accelerate_top_p(self): import torch - pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16) + pipe = pipeline( + model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16 + ) pipe("This is a test", do_sample=True, top_p=0.5) def test_pipeline_length_setting_warning(self):