better naming

This commit is contained in:
Patrick von Platen
2020-03-05 15:48:00 +01:00
parent ff648221bd
commit 7cba11fb9b
3 changed files with 6 additions and 7 deletions

View File

@@ -943,7 +943,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
def prepare_inputs_for_generation_bart(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return self.lm_head
@torch.no_grad()
def generate_1(
def generate_bart(
self,
input_ids,
attention_mask=None,
@@ -1113,7 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation_1(
model_inputs = self.prepare_inputs_for_generation_bart(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)