[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)
This commit is contained in:
@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel):
|
||||
config_class = BartConfig
|
||||
base_model_prefix = "model"
|
||||
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
encoder_outputs, decoder_cached_states = past, None
|
||||
else:
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
|
||||
Reference in New Issue
Block a user