typo (#6959)
there is no var `decoder_input_ids`, but there is `input_ids` for decoder :)
This commit is contained in:
@@ -411,7 +411,7 @@ class GenerationMixin:
|
|||||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
# create empty decoder_input_ids
|
# create empty decoder input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user