Fix beam search when using model parallel (#24969)
* Fix GPTNeoX beam search when using parallelize * Fix beam search idx device when using model parallel * remove onnx related stuff Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin * fix: add right item to _no_split_modules of MegaPreTrainedModel * fix: add num_beams within parallelized beam_search test Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -15,13 +15,14 @@
|
||||
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
|
||||
output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
|
||||
)
|
||||
|
||||
@require_accelerate
|
||||
@require_torch_multi_gpu
|
||||
def test_model_parallel_beam_search(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).eval()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
|
||||
|
||||
new_model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_beams=2,
|
||||
)
|
||||
|
||||
def test_beam_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
@@ -2482,34 +2482,6 @@ class ModelTesterMixin:
|
||||
for value_, parallel_value_ in zip(value, parallel_value):
|
||||
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_model_parallel_beam_search(self):
|
||||
if not self.test_model_parallel:
|
||||
return
|
||||
|
||||
all_generative_and_parallelizable_model_classes = tuple(
|
||||
set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in all_generative_and_parallelizable_model_classes:
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
def cast_to_device(dictionary, device):
|
||||
output = {}
|
||||
for k, v in dictionary.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
output[k] = v.to(device)
|
||||
else:
|
||||
output[k] = v
|
||||
|
||||
return output
|
||||
|
||||
model.parallelize()
|
||||
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||
|
||||
def check_device_map_is_respected(self, model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
|
||||
Reference in New Issue
Block a user