Adding support for device_map directly in pipeline(..) function. (#17902)
* Adding support for `device_map` directly in `pipeline(..)` function. * Updating the docstring. * Adding a better docstring * Put back type hints. * Blacked. (`make fixup` didn't work ??!!)
This commit is contained in:
@@ -15,7 +15,13 @@
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
require_accelerate,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
@@ -215,3 +221,63 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
handle_long_generation="hole",
|
||||
max_new_tokens=tokenizer.model_max_length + 10,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
def test_small_model_pt_bloom_accelerate(self):
|
||||
import torch
|
||||
|
||||
# Classic `model_kwargs`
|
||||
pipe = pipeline(
|
||||
model="hf-internal-testing/tiny-random-bloom",
|
||||
model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
|
||||
)
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"This is a test test test test test test test test test test test test test test test test"
|
||||
" test"
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.)
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"This is a test test test test test test test test test test test test test test test test"
|
||||
" test"
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# torch_dtype not necessary
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"This is a test test test test test test test test test test test test test test test test"
|
||||
" test"
|
||||
)
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user