[tests] enable test_pipeline_accelerate_top_p on XPU (#29309)
* use torch_device * Update tests/pipelines/test_pipelines_text_generation.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -450,7 +450,9 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
def test_pipeline_accelerate_top_p(self):
|
def test_pipeline_accelerate_top_p(self):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
pipe = pipeline(
|
||||||
|
model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16
|
||||||
|
)
|
||||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||||
|
|
||||||
def test_pipeline_length_setting_warning(self):
|
def test_pipeline_length_setting_warning(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user