Adding support for device_map directly in pipeline(..) function. (#17902)
* Adding support for `device_map` directly in `pipeline(..)` function. * Updating the docstring. * Adding a better docstring * Put back type hints. * Blacked. (`make fixup` didn't work ??!!)
This commit is contained in:
@@ -389,6 +389,8 @@ def pipeline(
|
||||
revision: Optional[str] = None,
|
||||
use_fast: bool = True,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
device_map=None,
|
||||
torch_dtype=None,
|
||||
model_kwargs: Dict[str, Any] = None,
|
||||
pipeline_class: Optional[Any] = None,
|
||||
**kwargs
|
||||
@@ -472,6 +474,20 @@ def pipeline(
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
|
||||
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
|
||||
`device_map="auto"` to compute the most optimized `device_map` automatically. [More
|
||||
information](https://huggingface.co/docs/accelerate/main/en/big_modeling#accelerate.cpu_offload)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Do not use `device_map` AND `device` at the same time as they will conflict
|
||||
|
||||
</Tip>
|
||||
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
|
||||
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
|
||||
model_kwargs:
|
||||
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
|
||||
**model_kwargs)` function.
|
||||
@@ -547,6 +563,20 @@ def pipeline(
|
||||
|
||||
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
||||
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
||||
if device_map is not None:
|
||||
if "device_map" in model_kwargs:
|
||||
raise ValueError(
|
||||
'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
|
||||
" arguments might conflict, use only one.)"
|
||||
)
|
||||
model_kwargs["device_map"] = device_map
|
||||
if torch_dtype is not None:
|
||||
if "torch_dtype" in model_kwargs:
|
||||
raise ValueError(
|
||||
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
|
||||
" arguments might conflict, use only one.)"
|
||||
)
|
||||
model_kwargs["torch_dtype"] = torch_dtype
|
||||
|
||||
# Config is the primordial information item.
|
||||
# Instantiate config if needed
|
||||
|
||||
@@ -15,7 +15,13 @@
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
require_accelerate,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
@@ -215,3 +221,63 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
handle_long_generation="hole",
|
||||
max_new_tokens=tokenizer.model_max_length + 10,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
def test_small_model_pt_bloom_accelerate(self):
|
||||
import torch
|
||||
|
||||
# Classic `model_kwargs`
|
||||
pipe = pipeline(
|
||||
model="hf-internal-testing/tiny-random-bloom",
|
||||
model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
|
||||
)
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"This is a test test test test test test test test test test test test test test test test"
|
||||
" test"
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.)
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"This is a test test test test test test test test test test test test test test test test"
|
||||
" test"
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# torch_dtype not necessary
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"This is a test test test test test test test test test test test test test test test test"
|
||||
" test"
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user