[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)

This commit is contained in:
Sam Shleifer
2020-03-26 18:41:19 -04:00
committed by GitHub
parent 39371ee454
commit 1a5aefc95c
3 changed files with 29 additions and 15 deletions

View File

@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
def _init_weights(self, module):
std = self.config.init_std
@@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs, decoder_cached_states = past, None
else:
encoder_outputs, decoder_cached_states = past
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,