[tests] enable test_pipeline_accelerate_top_p on XPU (#29309)

* use torch_device

* Update tests/pipelines/test_pipelines_text_generation.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Fanli Lin
2024-03-05 16:16:05 +08:00
committed by GitHub
parent ebccb09169
commit fa7f3cf336

View File

@@ -450,7 +450,9 @@ class TextGenerationPipelineTests(unittest.TestCase):
def test_pipeline_accelerate_top_p(self): def test_pipeline_accelerate_top_p(self):
import torch import torch
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16) pipe = pipeline(
model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16
)
pipe("This is a test", do_sample=True, top_p=0.5) pipe("This is a test", do_sample=True, top_p=0.5)
def test_pipeline_length_setting_warning(self): def test_pipeline_length_setting_warning(self):