From ccc089780415445768bcfd3ac4418cec20353484 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Jul 2022 15:54:26 +0200 Subject: [PATCH] Adding support for `device_map` directly in `pipeline(..)` function. (#17902) * Adding support for `device_map` directly in `pipeline(..)` function. * Updating the docstring. * Adding a better docstring * Put back type hints. * Blacked. (`make fixup` didn't work ??!!) --- src/transformers/pipelines/__init__.py | 30 ++++++++ .../test_pipelines_text_generation.py | 68 ++++++++++++++++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index e0c754a85d..e563b28427 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -389,6 +389,8 @@ def pipeline( revision: Optional[str] = None, use_fast: bool = True, use_auth_token: Optional[Union[str, bool]] = None, + device_map=None, + torch_dtype=None, model_kwargs: Dict[str, Any] = None, pipeline_class: Optional[Any] = None, **kwargs @@ -472,6 +474,20 @@ def pipeline( use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). + device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*): + Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set + `device_map="auto"` to compute the most optimized `device_map` automatically. [More + information](https://huggingface.co/docs/accelerate/main/en/big_modeling#accelerate.cpu_offload) + + + + Do not use `device_map` AND `device` at the same time as they will conflict + + + + torch_dtype (`str` or `torch.dtype`, *optional*): + Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model + (`torch.float16`, `torch.bfloat16`, ... or `"auto"`). model_kwargs: Additional dictionary of keyword arguments passed along to the model's `from_pretrained(..., **model_kwargs)` function. @@ -547,6 +563,20 @@ def pipeline( # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token) + if device_map is not None: + if "device_map" in model_kwargs: + raise ValueError( + 'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those' + " arguments might conflict, use only one.)" + ) + model_kwargs["device_map"] = device_map + if torch_dtype is not None: + if "torch_dtype" in model_kwargs: + raise ValueError( + 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' + " arguments might conflict, use only one.)" + ) + model_kwargs["torch_dtype"] = torch_dtype # Config is the primordial information item. # Instantiate config if needed diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 929e2732f0..a26ed56d4c 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -15,7 +15,13 @@ import unittest from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline -from transformers.testing_utils import is_pipeline_test, require_tf, require_torch +from transformers.testing_utils import ( + is_pipeline_test, + require_accelerate, + require_tf, + require_torch, + require_torch_gpu, +) from .test_pipelines_common import ANY, PipelineTestCaseMeta @@ -215,3 +221,63 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM handle_long_generation="hole", max_new_tokens=tokenizer.model_max_length + 10, ) + + @require_torch + @require_accelerate + @require_torch_gpu + def test_small_model_pt_bloom_accelerate(self): + import torch + + # Classic `model_kwargs` + pipe = pipeline( + model="hf-internal-testing/tiny-random-bloom", + model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16}, + ) + self.assertEqual(pipe.model.device, torch.device(0)) + self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) + out = pipe("This is a test") + self.assertEqual( + out, + [ + { + "generated_text": ( + "This is a test test test test test test test test test test test test test test test test" + " test" + ) + } + ], + ) + + # Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.) + pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16) + self.assertEqual(pipe.model.device, torch.device(0)) + self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) + out = pipe("This is a test") + self.assertEqual( + out, + [ + { + "generated_text": ( + "This is a test test test test test test test test test test test test test test test test" + " test" + ) + } + ], + ) + + # torch_dtype not necessary + pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto") + self.assertEqual(pipe.model.device, torch.device(0)) + self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) + out = pipe("This is a test") + self.assertEqual( + out, + [ + { + "generated_text": ( + "This is a test test test test test test test test test test test test test test test test" + " test" + ) + } + ], + )