fix conflicts

This commit is contained in:
Patrick von Platen
2020-03-06 11:28:10 +01:00
parent d6de6423ba
commit d8e2b3c547
4 changed files with 178 additions and 76 deletions

View File

@@ -16,6 +16,7 @@
import logging
import math
import random
import ipdb
from typing import Dict, List, Optional, Tuple
import torch
@@ -943,7 +944,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@@ -954,7 +955,21 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
# "decoder_attention_mask": decoder_attention_mask,
}
@staticmethod
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
encoder_outputs, decoder_cached_states = past
input_ids = encoder_inputs
return {
"input_ids": input_ids, # ignored after first pass
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
}
@staticmethod
@@ -979,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return self.lm_head
@torch.no_grad()
def generate(
def generate_1(
self,
input_ids,
attention_mask=None,
@@ -1099,7 +1114,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation(
model_inputs = self.prepare_inputs_for_generation_1(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)