fix condition where torch_dtype auto collides with model_kwargs. (#39054)
* fix condition where torch_dtype auto collides with model_kwargs. * update tests * update comment * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -161,11 +161,11 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
[
|
||||
{
|
||||
"input_text": "<image> What this is? Assistant: This is",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
|
||||
},
|
||||
{
|
||||
"input_text": "<image> What this is? Assistant: This is",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -441,11 +441,11 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
[{"generated_text": ("This is a test test test test test test")}],
|
||||
)
|
||||
|
||||
# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
|
||||
# torch_dtype will be automatically set to torch.bfloat16 if not provided - check: https://github.com/huggingface/transformers/pull/38882
|
||||
pipe = pipeline(
|
||||
model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False
|
||||
)
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
|
||||
Reference in New Issue
Block a user