fix condition where torch_dtype auto collides with model_kwargs. (#39054)
* fix condition where torch_dtype auto collides with model_kwargs. * update tests * update comment * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1005,10 +1005,18 @@ def pipeline(
|
|||||||
model_kwargs["device_map"] = device_map
|
model_kwargs["device_map"] = device_map
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
if "torch_dtype" in model_kwargs:
|
if "torch_dtype" in model_kwargs:
|
||||||
|
# If the user did not explicitly provide `torch_dtype` (i.e. the function default "auto" is still
|
||||||
|
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
|
||||||
|
# raising. This prevents false positives like providing `torch_dtype` only via `model_kwargs` while the
|
||||||
|
# top-level argument keeps its default value "auto".
|
||||||
|
if torch_dtype == "auto":
|
||||||
|
torch_dtype = None
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
|
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
|
||||||
" arguments might conflict, use only one.)"
|
" arguments might conflict, use only one.)"
|
||||||
)
|
)
|
||||||
|
if torch_dtype is not None:
|
||||||
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
|
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
|
||||||
torch_dtype = getattr(torch, torch_dtype)
|
torch_dtype = getattr(torch, torch_dtype)
|
||||||
model_kwargs["torch_dtype"] = torch_dtype
|
model_kwargs["torch_dtype"] = torch_dtype
|
||||||
|
|||||||
@@ -161,11 +161,11 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"input_text": "<image> What this is? Assistant: This is",
|
"input_text": "<image> What this is? Assistant: This is",
|
||||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
|
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"input_text": "<image> What this is? Assistant: This is",
|
"input_text": "<image> What this is? Assistant: This is",
|
||||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
|
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -441,11 +441,11 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
[{"generated_text": ("This is a test test test test test test")}],
|
[{"generated_text": ("This is a test test test test test test")}],
|
||||||
)
|
)
|
||||||
|
|
||||||
# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
|
# torch_dtype will be automatically set to torch.bfloat16 if not provided - check: https://github.com/huggingface/transformers/pull/38882
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False
|
model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False
|
||||||
)
|
)
|
||||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
|
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||||
out = pipe("This is a test")
|
out = pipe("This is a test")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
out,
|
out,
|
||||||
|
|||||||
Reference in New Issue
Block a user