Fix beam search when using model parallel (#24969)

* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Dong-Yong Lee
2023-09-15 00:00:52 +09:00
committed by GitHub
parent 0dd06c3f78
commit 8881f38a4f
53 changed files with 191 additions and 95 deletions

View File

@@ -15,13 +15,14 @@
import inspect
import tempfile
import unittest
import warnings
import numpy as np
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
)
@require_accelerate
@require_torch_multi_gpu
def test_model_parallel_beam_search(self):
for model_class in self.all_generative_model_classes:
if model_class._no_split_modules is None:
continue
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).eval()
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
new_model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=2,
)
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()