Adding support for device_map directly in pipeline(..) function. (#17902)

* Adding support for `device_map` directly in `pipeline(..)` function.

* Updating the docstring.

* Adding a better docstring

* Put back type hints.

* Blacked. (`make fixup` didn't work ??!!)
This commit is contained in:
Nicolas Patry
2022-07-15 15:54:26 +02:00
committed by GitHub
parent fca66ec4ef
commit ccc0897804
2 changed files with 97 additions and 1 deletions

View File

@@ -15,7 +15,13 @@
import unittest
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
from transformers.testing_utils import (
is_pipeline_test,
require_accelerate,
require_tf,
require_torch,
require_torch_gpu,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -215,3 +221,63 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)
@require_torch
@require_accelerate
@require_torch_gpu
def test_small_model_pt_bloom_accelerate(self):
import torch
# Classic `model_kwargs`
pipe = pipeline(
model="hf-internal-testing/tiny-random-bloom",
model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
)
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
[
{
"generated_text": (
"This is a test test test test test test test test test test test test test test test test"
" test"
)
}
],
)
# Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.)
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
[
{
"generated_text": (
"This is a test test test test test test test test test test test test test test test test"
" test"
)
}
],
)
# torch_dtype not necessary
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
[
{
"generated_text": (
"This is a test test test test test test test test test test test test test test test test"
" test"
)
}
],
)