finalize generation merge

This commit is contained in:
Patrick von Platen
2020-03-11 11:53:36 +01:00
parent 1ba21f96ca
commit a332cc9f7f
4 changed files with 10 additions and 13 deletions

View File

@@ -840,14 +840,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
eos_token_id = eos_token_ids[0]
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
# eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,