[Pipeline] fix failing bloom pipeline test (#20778)
fix failing `pipeline` test
This commit is contained in:
@@ -284,10 +284,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# torch_dtype not necessary
|
# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
|
||||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
|
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
|
||||||
self.assertEqual(pipe.model.device, torch.device(0))
|
self.assertEqual(pipe.model.device, torch.device(0))
|
||||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
|
||||||
out = pipe("This is a test")
|
out = pipe("This is a test")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
out,
|
out,
|
||||||
|
|||||||
Reference in New Issue
Block a user