From d973e62fdd86d64259f87debc46bbcbf6c7e5de2 Mon Sep 17 00:00:00 2001 From: vb Date: Thu, 26 Jun 2025 14:52:57 +0200 Subject: [PATCH] fix condition where torch_dtype auto collides with model_kwargs. (#39054) * fix condition where torch_dtype auto collides with model_kwargs. * update tests * update comment * fix --------- Co-authored-by: ydshieh --- src/transformers/pipelines/__init__.py | 22 +++++++++++++------ .../test_pipelines_image_text_to_text.py | 4 ++-- .../test_pipelines_text_generation.py | 4 ++-- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index fe829d51ea..2b433d9c7f 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -1005,13 +1005,21 @@ def pipeline( model_kwargs["device_map"] = device_map if torch_dtype is not None: if "torch_dtype" in model_kwargs: - raise ValueError( - 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' - " arguments might conflict, use only one.)" - ) - if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): - torch_dtype = getattr(torch, torch_dtype) - model_kwargs["torch_dtype"] = torch_dtype + # If the user did not explicitly provide `torch_dtype` (i.e. the function default "auto" is still + # present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of + # raising. This prevents false positives like providing `torch_dtype` only via `model_kwargs` while the + # top-level argument keeps its default value "auto". + if torch_dtype == "auto": + torch_dtype = None + else: + raise ValueError( + 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' + " arguments might conflict, use only one.)" + ) + if torch_dtype is not None: + if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + model_kwargs["torch_dtype"] = torch_dtype model_name = model if isinstance(model, str) else None diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py index 5e38130a11..781fbad8a9 100644 --- a/tests/pipelines/test_pipelines_image_text_to_text.py +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -161,11 +161,11 @@ class ImageTextToTextPipelineTests(unittest.TestCase): [ { "input_text": " What this is? Assistant: This is", - "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable", + "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they", }, { "input_text": " What this is? Assistant: This is", - "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable", + "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they", }, ], ) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index dd13219557..d92a3aefec 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -441,11 +441,11 @@ class TextGenerationPipelineTests(unittest.TestCase): [{"generated_text": ("This is a test test test test test test")}], ) - # torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602 + # torch_dtype will be automatically set to torch.bfloat16 if not provided - check: https://github.com/huggingface/transformers/pull/38882 pipe = pipeline( model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False ) - self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32) + self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) out = pipe("This is a test") self.assertEqual( out,