finalize generation merge
This commit is contained in:
@@ -840,14 +840,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
eos_token_id = eos_token_ids[0]
|
||||
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
||||
assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
|
||||
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
|
||||
encoder_inputs = input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
# eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
|
||||
bos_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
|
||||
Reference in New Issue
Block a user