fix conflicts

This commit is contained in:
Patrick von Platen
2020-03-06 11:28:10 +01:00
parent d6de6423ba
commit d8e2b3c547
4 changed files with 178 additions and 76 deletions

View File

@@ -16,6 +16,7 @@
import logging
import math
import random
import ipdb
from typing import Dict, List, Optional, Tuple
import torch
@@ -943,7 +944,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@@ -954,7 +955,21 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
# "decoder_attention_mask": decoder_attention_mask,
}
@staticmethod
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
    """Assemble the keyword arguments passed to the model for one decoding step.

    Args:
        decoder_input_ids: tokens generated so far, forwarded unchanged.
        past: ``None`` on the first step; otherwise a pair that is unpacked
            into ``(encoder_outputs, decoder_cached_states)``.
        encoder_inputs: the encoder-side input ids, forwarded as ``input_ids``.

    Returns:
        A dict with the keys ``input_ids``, ``encoder_outputs``,
        ``decoder_cached_states`` and ``decoder_input_ids``.
    """
    # Without a cache (first step) both cached entries are None.
    encoder_outputs, decoder_cached_states = (None, None) if past is None else past
    return {
        "input_ids": encoder_inputs,  # ignored after first pass
        "encoder_outputs": encoder_outputs,
        "decoder_cached_states": decoder_cached_states,
        "decoder_input_ids": decoder_input_ids,
    }
@staticmethod
@@ -979,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return self.lm_head
@torch.no_grad()
def generate(
def generate_1(
self,
input_ids,
attention_mask=None,
@@ -1099,7 +1114,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation(
model_inputs = self.prepare_inputs_for_generation_1(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)

View File

@@ -787,7 +787,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id = eos_token_ids[0]
# current position and vocab size
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
# set effective batch size and effective batch multiplier according to do_sample
@@ -806,6 +805,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
# TODO (PVP): check eos_token_id
# TODO (PVP): probably not the best way to check whether model is encoder decoder
is_encoder_decoder = (
hasattr(self, "model")
and hasattr(self.model, "decoder")
and hasattr(self.model, "encoder")
)
if is_encoder_decoder:
eos_token_id = eos_token_ids[0]
assert (
bos_token_id is not None
), "Encoder Decoder Models need to have a bos_token_id"
assert (
eos_token_id is not None
), "Encoder Decoder Models need to have a eos_token_id"
# encoder-decoder models start decoding from a single bos_token_id column; the original input_ids are kept as encoder_inputs
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
# eos_token_id,
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
cur_len = 0
self.model.decoder.generation_mode = True
else:
encoder_inputs = None
cur_len = input_ids.shape[-1]
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
@@ -823,6 +852,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
length_penalty,
num_beams,
vocab_size,
encoder_inputs,
)
else:
output = self._generate_no_beam_search(
@@ -837,6 +867,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id,
eos_token_ids,
effective_batch_size,
encoder_inputs,
)
return output
@@ -854,6 +885,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id,
eos_token_ids,
batch_size,
encoder_inputs,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
@@ -864,7 +896,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@@ -943,9 +977,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
length_penalty,
num_beams,
vocab_size,
encoder_inputs,
):
""" Generate sequences for each example with beam search.
"""
is_encoder_decoder = (
hasattr(self, "model")
and hasattr(self.model, "decoder")
and hasattr(self.model, "encoder")
)
# generated hypotheses
generated_hyps = [
@@ -966,7 +1006,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@@ -1012,6 +1054,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
# do greedy beam search
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if is_encoder_decoder: # TODO(PVP) to be refactored later
import math
# scores[scores != scores] = -math.inf # block nans
# scores[:, pad_token_id] = -math.inf
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
# if cur_len == 0: # Force BOS to be chosen
# scores[:, self.config.bos_token_id + 1 :] = -math.inf # TODO(PVP) should not use bos_token_id here
# elif cur_len < min_len: # Prevent EOS from being chosen TODO: for the moment don't think about min_len
# scores[:, eos_token_ids[0]] = -math.inf
# elif cur_len == max_length: # FORCE EOS to be chosen
if cur_len == max_length: # FORCE EOS to be chosen
scores[:, :eos_token_ids[0]] = -math.inf
scores[:, eos_token_ids[0] + 1 :] = -math.inf
assert scores.size() == (batch_size * num_beams, vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
@@ -1137,7 +1194,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
@@ -1150,7 +1207,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded
return decoded[:, 1:]
@staticmethod
def _reorder_cache(past, beam_idx):