Use Accelerate in from_pretrained for big model inference (#17341)

* Initial work

* More or less finished with first draft

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix randomly initialized weights

* Update src/transformers/modeling_utils.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Address review comments

* Rename DeepSpeed folder to temporarily fix the test issue?

* Revert to try if Accelerate fix works

* Use latest Accelerate release

* Quality and fixes

* Style

* Quality

* Add doc

* Test + fix

* More blocks

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2022-05-23 14:32:21 -04:00
committed by GitHub
parent 2e7e4280aa
commit 56f50590d5
9 changed files with 270 additions and 41 deletions

View File

@@ -94,6 +94,8 @@ if is_torch_available():
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
AdaptiveEmbedding,
AutoModelForCausalLM,
AutoTokenizer,
BertConfig,
BertModel,
PreTrainedModel,
@@ -2595,6 +2597,22 @@ class ModelUtilsTest(TestCasePlus):
# functionality to load models directly on gpu, this test can be rewritten to use torch's
# cuda memory tracking and then we should be able to do a much more precise test.
@require_torch_multi_gpu
@slow
def test_model_parallelism_gpt2(self):
device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
for i in range(12):
device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(inputs["input_ids"].to(0))
text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()