Fixed beam search generation for GPT2 and T5 (#9219)
This commit is contained in:
@@ -156,7 +156,7 @@ class GenerationMixin:
|
||||
if is_encoder_decoder:
|
||||
assert encoder_outputs is not None
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
||||
0, expanded_return_idx
|
||||
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
|
||||
)
|
||||
model_kwargs["encoder_outputs"] = encoder_outputs
|
||||
return input_ids, model_kwargs
|
||||
@@ -226,7 +226,7 @@ class GenerationMixin:
|
||||
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
||||
subclasses of :class:`~transformers.PreTrainedModel`.
|
||||
"""
|
||||
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
||||
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
||||
|
||||
Reference in New Issue
Block a user