Fixed beam search generation for GPT2 and T5 (#9219)
This commit is contained in:
@@ -156,7 +156,7 @@ class GenerationMixin:
|
|||||||
if is_encoder_decoder:
|
if is_encoder_decoder:
|
||||||
assert encoder_outputs is not None
|
assert encoder_outputs is not None
|
||||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
||||||
0, expanded_return_idx
|
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
|
||||||
)
|
)
|
||||||
model_kwargs["encoder_outputs"] = encoder_outputs
|
model_kwargs["encoder_outputs"] = encoder_outputs
|
||||||
return input_ids, model_kwargs
|
return input_ids, model_kwargs
|
||||||
@@ -226,7 +226,7 @@ class GenerationMixin:
|
|||||||
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
||||||
subclasses of :class:`~transformers.PreTrainedModel`.
|
subclasses of :class:`~transformers.PreTrainedModel`.
|
||||||
"""
|
"""
|
||||||
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
|
||||||
|
|
||||||
def _get_logits_warper(
|
def _get_logits_warper(
|
||||||
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
||||||
|
|||||||
@@ -1166,6 +1166,34 @@ class ModelTesterMixin:
|
|||||||
for value_, parallel_value_ in zip(value, parallel_value):
|
for value_, parallel_value_ in zip(value, parallel_value):
|
||||||
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
|
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_model_parallel_beam_search(self):
|
||||||
|
if not self.test_model_parallel:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_generative_and_parallelizable_model_classes = tuple(
|
||||||
|
set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in all_generative_and_parallelizable_model_classes:
|
||||||
|
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
def cast_to_device(dictionary, device):
|
||||||
|
output = {}
|
||||||
|
for k, v in dictionary.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
output[k] = v.to(device)
|
||||||
|
else:
|
||||||
|
output[k] = v
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
model.parallelize()
|
||||||
|
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user