[tests] enable test_pipeline_accelerate_top_p on XPU (#29309)
* use torch_device * Update tests/pipelines/test_pipelines_text_generation.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -450,7 +450,9 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
def test_pipeline_accelerate_top_p(self):
|
||||
import torch
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
||||
pipe = pipeline(
|
||||
model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||
|
||||
def test_pipeline_length_setting_warning(self):
|
||||
|
||||
Reference in New Issue
Block a user