fix conflicts
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import ipdb
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -943,7 +944,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
|
||||
def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
@@ -954,7 +955,21 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
# "decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
|
||||
input_ids = encoder_inputs
|
||||
return {
|
||||
"input_ids": input_ids, # ignored after first pass
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -979,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return self.lm_head
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
def generate_1(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
@@ -1099,7 +1114,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
|
||||
for step in range(max_length + 1):
|
||||
decoder_input_ids = prev_output_tokens.clone()
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
model_inputs = self.prepare_inputs_for_generation_1(
|
||||
input_ids, decoder_cache, decoder_input_ids, attention_mask,
|
||||
)
|
||||
outputs = self(**model_inputs)
|
||||
|
||||
Reference in New Issue
Block a user