Adding support for device_map directly in pipeline(..) function. (#17902)
* Adding support for `device_map` directly in `pipeline(..)` function. * Updating the docstring. * Adding a better docstring * Put back type hints. * Blacked. (`make fixup` didn't work ??!!)
This commit is contained in:
@@ -389,6 +389,8 @@ def pipeline(
|
|||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
use_fast: bool = True,
|
use_fast: bool = True,
|
||||||
use_auth_token: Optional[Union[str, bool]] = None,
|
use_auth_token: Optional[Union[str, bool]] = None,
|
||||||
|
device_map=None,
|
||||||
|
torch_dtype=None,
|
||||||
model_kwargs: Dict[str, Any] = None,
|
model_kwargs: Dict[str, Any] = None,
|
||||||
pipeline_class: Optional[Any] = None,
|
pipeline_class: Optional[Any] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -472,6 +474,20 @@ def pipeline(
|
|||||||
use_auth_token (`str` or *bool*, *optional*):
|
use_auth_token (`str` or *bool*, *optional*):
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||||
|
device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
|
||||||
|
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
|
||||||
|
`device_map="auto"` to compute the most optimized `device_map` automatically. [More
|
||||||
|
information](https://huggingface.co/docs/accelerate/main/en/big_modeling#accelerate.cpu_offload)
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
Do not use `device_map` AND `device` at the same time as they will conflict
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||||
|
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
|
||||||
|
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
|
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
|
||||||
**model_kwargs)` function.
|
**model_kwargs)` function.
|
||||||
@@ -547,6 +563,20 @@ def pipeline(
|
|||||||
|
|
||||||
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
||||||
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
||||||
|
if device_map is not None:
|
||||||
|
if "device_map" in model_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
|
||||||
|
" arguments might conflict, use only one.)"
|
||||||
|
)
|
||||||
|
model_kwargs["device_map"] = device_map
|
||||||
|
if torch_dtype is not None:
|
||||||
|
if "torch_dtype" in model_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
|
||||||
|
" arguments might conflict, use only one.)"
|
||||||
|
)
|
||||||
|
model_kwargs["torch_dtype"] = torch_dtype
|
||||||
|
|
||||||
# Config is the primordial information item.
|
# Config is the primordial information item.
|
||||||
# Instantiate config if needed
|
# Instantiate config if needed
|
||||||
|
|||||||
@@ -15,7 +15,13 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
||||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
require_accelerate,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
@@ -215,3 +221,63 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
|||||||
handle_long_generation="hole",
|
handle_long_generation="hole",
|
||||||
max_new_tokens=tokenizer.model_max_length + 10,
|
max_new_tokens=tokenizer.model_max_length + 10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_accelerate
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_small_model_pt_bloom_accelerate(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Classic `model_kwargs`
|
||||||
|
pipe = pipeline(
|
||||||
|
model="hf-internal-testing/tiny-random-bloom",
|
||||||
|
model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
|
||||||
|
)
|
||||||
|
self.assertEqual(pipe.model.device, torch.device(0))
|
||||||
|
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||||
|
out = pipe("This is a test")
|
||||||
|
self.assertEqual(
|
||||||
|
out,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": (
|
||||||
|
"This is a test test test test test test test test test test test test test test test test"
|
||||||
|
" test"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.)
|
||||||
|
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
|
||||||
|
self.assertEqual(pipe.model.device, torch.device(0))
|
||||||
|
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||||
|
out = pipe("This is a test")
|
||||||
|
self.assertEqual(
|
||||||
|
out,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": (
|
||||||
|
"This is a test test test test test test test test test test test test test test test test"
|
||||||
|
" test"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# torch_dtype not necessary
|
||||||
|
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
|
||||||
|
self.assertEqual(pipe.model.device, torch.device(0))
|
||||||
|
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||||
|
out = pipe("This is a test")
|
||||||
|
self.assertEqual(
|
||||||
|
out,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": (
|
||||||
|
"This is a test test test test test test test test test test test test test test test test"
|
||||||
|
" test"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user