Fix beam search when using model parallel (#24969)

* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Dong-Yong Lee
2023-09-15 00:00:52 +09:00
committed by GitHub
parent 0dd06c3f78
commit 8881f38a4f
53 changed files with 191 additions and 95 deletions

View File

@@ -15,13 +15,14 @@
import inspect
import tempfile
import unittest
import warnings
import numpy as np
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
)
@require_accelerate
@require_torch_multi_gpu
def test_model_parallel_beam_search(self):
for model_class in self.all_generative_model_classes:
if model_class._no_split_modules is None:
continue
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).eval()
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
new_model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=2,
)
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

View File

@@ -2482,34 +2482,6 @@ class ModelTesterMixin:
for value_, parallel_value_ in zip(value, parallel_value):
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
@require_torch_multi_gpu
def test_model_parallel_beam_search(self):
if not self.test_model_parallel:
return
all_generative_and_parallelizable_model_classes = tuple(
set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in all_generative_and_parallelizable_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
def cast_to_device(dictionary, device):
output = {}
for k, v in dictionary.items():
if isinstance(v, torch.Tensor):
output[k] = v.to(device)
else:
output[k] = v
return output
model.parallelize()
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
def check_device_map_is_respected(self, model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map