fix conflicts
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import ipdb
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -943,7 +944,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
|
||||
def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
@@ -954,7 +955,21 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
# "decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
|
||||
input_ids = encoder_inputs
|
||||
return {
|
||||
"input_ids": input_ids, # ignored after first pass
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -979,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return self.lm_head
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
def generate_1(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
@@ -1099,7 +1114,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
|
||||
for step in range(max_length + 1):
|
||||
decoder_input_ids = prev_output_tokens.clone()
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
model_inputs = self.prepare_inputs_for_generation_1(
|
||||
input_ids, decoder_cache, decoder_input_ids, attention_mask,
|
||||
)
|
||||
outputs = self(**model_inputs)
|
||||
|
||||
@@ -787,7 +787,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
pad_token_id = eos_token_ids[0]
|
||||
|
||||
# current position and vocab size
|
||||
cur_len = input_ids.shape[1]
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
# set effective batch size and effective batch multiplier according to do_sample
|
||||
@@ -806,6 +805,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
effective_batch_size * num_beams, input_ids_len
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
# TODO (PVP): check eos_token_id
|
||||
# TODO (PVP): probably not the best way to check whether model is encoder decoder
|
||||
is_encoder_decoder = (
|
||||
hasattr(self, "model")
|
||||
and hasattr(self.model, "decoder")
|
||||
and hasattr(self.model, "encoder")
|
||||
)
|
||||
if is_encoder_decoder:
|
||||
eos_token_id = eos_token_ids[0]
|
||||
assert (
|
||||
bos_token_id is not None
|
||||
), "Encoder Decoder Models need to have a bos_token_id"
|
||||
assert (
|
||||
eos_token_id is not None
|
||||
), "Encoder Decoder Models need to have a eos_token_id"
|
||||
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
|
||||
encoder_inputs = input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
# eos_token_id,
|
||||
bos_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
cur_len = 0
|
||||
self.model.decoder.generation_mode = True
|
||||
else:
|
||||
encoder_inputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
if num_beams > 1:
|
||||
output = self._generate_beam_search(
|
||||
input_ids,
|
||||
@@ -823,6 +852,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_inputs,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
@@ -837,6 +867,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
effective_batch_size,
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -854,6 +885,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
batch_size,
|
||||
encoder_inputs,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
@@ -864,7 +896,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
past = None
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, encoder_inputs=encoder_inputs
|
||||
)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
@@ -943,9 +977,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_inputs,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
is_encoder_decoder = (
|
||||
hasattr(self, "model")
|
||||
and hasattr(self.model, "decoder")
|
||||
and hasattr(self.model, "encoder")
|
||||
)
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
@@ -966,7 +1006,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, encoder_inputs=encoder_inputs
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
@@ -1012,6 +1054,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
else:
|
||||
# do greedy beam search
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
if is_encoder_decoder: # TODO(PVP) to be refactored later
|
||||
import math
|
||||
# scores[scores != scores] = -math.inf # block nans
|
||||
# scores[:, pad_token_id] = -math.inf
|
||||
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
||||
# if cur_len == 0: # Force BOS to be chosen
|
||||
# scores[:, self.config.bos_token_id + 1 :] = -math.inf # TODO(PVP) should not use bos_token_id here
|
||||
# elif cur_len < min_len: # Prevent EOS from being chosen TODO: for the moment don't think about min_len
|
||||
# scores[:, eos_token_ids[0]] = -math.inf
|
||||
# elif cur_len == max_length: # FORCE EOS to be chosen
|
||||
if cur_len == max_length: # FORCE EOS to be chosen
|
||||
scores[:, :eos_token_ids[0]] = -math.inf
|
||||
scores[:, eos_token_ids[0] + 1 :] = -math.inf
|
||||
|
||||
assert scores.size() == (batch_size * num_beams, vocab_size)
|
||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
@@ -1137,7 +1194,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# shorter batches are filled with pad_token
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
||||
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
@@ -1150,7 +1207,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
|
||||
return decoded
|
||||
return decoded[:, 1:]
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
|
||||
Reference in New Issue
Block a user