Use Accelerate in from_pretrained for big model inference (#17341)
* Initial work
* More or less finished with first draft
* Update src/transformers/modeling_utils.py
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Update src/transformers/modeling_utils.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Fix randomly initialized weights
* Update src/transformers/modeling_utils.py
  Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
* Address review comments
* Rename DeepSpeed folder to temporarily fix the test issue?
* Revert to try if Accelerate fix works
* Use latest Accelerate release
* Quality and fixes
* Style
* Quality
* Add doc
* Test + fix
* More blocks

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -94,6 +94,8 @@ if is_torch_available():
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
AdaptiveEmbedding,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
PreTrainedModel,
|
||||
@@ -2595,6 +2597,22 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# functionality to load models directly on gpu, this test can be rewritten to use torch's
|
||||
# cuda memory tracking and then we should be able to do a much more precise test.
|
||||
|
||||
@require_torch_multi_gpu
@slow
def test_model_parallelism_gpt2(self):
    # Load GPT-2 with an explicit device map split across devices 0 and 1,
    # then check that generation from a fixed prompt yields the expected text.

    # Embeddings and the LM head are mapped to device 0; the final layer norm to device 1.
    device_map = {
        "transformer.wte": 0,
        "transformer.wpe": 0,
        "lm_head": 0,
        "transformer.ln_f": 1,
    }
    # Transformer blocks 0-5 are mapped to device 0, blocks 6-11 to device 1.
    device_map.update({f"transformer.h.{block_idx}": (1 if block_idx > 5 else 0) for block_idx in range(12)})

    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    encoded = tokenizer("Hello, my name is", return_tensors="pt")
    # The prompt ids are moved to device 0 before generating.
    prompt_ids = encoded["input_ids"].to(0)
    generated = model.generate(prompt_ids)

    decoded = tokenizer.decode(generated[0].tolist())
    self.assertEqual(decoded, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
|
||||
Reference in New Issue
Block a user