best current version and make style

This commit is contained in:
patrickvonplaten
2020-03-06 22:19:01 +01:00
committed by Patrick von Platen
parent c62444da39
commit 2acfe63964
4 changed files with 45 additions and 36 deletions

View File

@@ -79,8 +79,7 @@ class T5Config(PretrainedConfig):
**kwargs
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
**kwargs,
is_encoder_decoder=is_encoder_decoder, **kwargs,
)
self.vocab_size = vocab_size
self.n_positions = n_positions

View File

@@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
attention_mask.shape, encoder_inputs.shape
)
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):

View File

@@ -19,8 +19,6 @@ import logging
import os
import typing
import ipdb
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -829,7 +827,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len
@@ -846,9 +846,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
eos_token_id,
# bos_token_id,
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
dtype=torch.long,
device=next(self.parameters()).device,
)
@@ -919,7 +917,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
eos_token_ids,
batch_size,
encoder_inputs,
attention_mask
attention_mask,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
@@ -930,7 +928,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@@ -948,16 +948,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
for batch_idx in range(batch_size):
next_token_logits[
batch_idx, banned_tokens[batch_idx]
] = -float('inf')
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
next_token_logits[
:, eos_token_id
] = -float('inf')
next_token_logits[:, eos_token_id] = -float("inf")
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -995,7 +991,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# extend attention_mask for new generated input
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
)
cur_len = cur_len + 1
@@ -1041,8 +1039,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
for _ in range(batch_size)
]
# scores for each sentence in the beam
@@ -1060,7 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@@ -1084,17 +1084,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
scores[:, eos_token_id] = -float('inf')
scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
banned_batch_tokens = calc_banned_tokens(
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float('inf')
scores[i, banned_tokens] = -float("inf")
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size))
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)
if do_sample:
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
@@ -1203,7 +1207,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# extend attention_mask for new generated input
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
)
# update current length
cur_len = cur_len + 1
@@ -1278,7 +1284,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device=next(self.parameters()).device,
)
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float('inf')
scores[:, all_but_token_ids_mask] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):

View File

@@ -104,7 +104,9 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
# TODO(SS): fix the below in a separate PR
@@ -472,7 +474,7 @@ class BartModelIntegrationTest(unittest.TestCase):
min_length=min_length + 1,
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True
early_stopping=True,
)
decoded = [