best current version and make style

This commit is contained in:
patrickvonplaten
2020-03-06 22:19:01 +01:00
committed by Patrick von Platen
parent c62444da39
commit 2acfe63964
4 changed files with 45 additions and 36 deletions

View File

@@ -79,8 +79,7 @@ class T5Config(PretrainedConfig):
**kwargs **kwargs
): ):
super().__init__( super().__init__(
is_encoder_decoder=is_encoder_decoder, is_encoder_decoder=is_encoder_decoder, **kwargs,
**kwargs,
) )
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_positions = n_positions self.n_positions = n_positions

View File

@@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask): def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape) assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
attention_mask.shape, encoder_inputs.shape
)
if past is None: # first step if past is None: # first step
encoder_outputs, decoder_cached_states = None, None encoder_outputs, decoder_cached_states = None, None
else: else:
@@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states, "decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask "attention_mask": attention_mask,
} }
def prepare_scores_for_generation(self, scores, cur_len, max_length): def prepare_scores_for_generation(self, scores, cur_len, max_length):

View File

@@ -19,8 +19,6 @@ import logging
import os import os
import typing import typing
import ipdb
import torch import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
@@ -829,7 +827,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if num_return_sequences > 1 or num_beams > 1: if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1] input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) attention_mask = attention_mask.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
input_ids = input_ids.contiguous().view( input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len effective_batch_size * num_beams, input_ids_len
@@ -846,9 +846,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids encoder_inputs = input_ids
input_ids = torch.full( input_ids = torch.full(
(effective_batch_size * num_beams, 1), (effective_batch_size * num_beams, 1),
eos_token_id, eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
# bos_token_id,
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
dtype=torch.long, dtype=torch.long,
device=next(self.parameters()).device, device=next(self.parameters()).device,
) )
@@ -919,7 +917,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
eos_token_ids, eos_token_ids,
batch_size, batch_size,
encoder_inputs, encoder_inputs,
attention_mask attention_mask,
): ):
""" Generate sequences for each example without beam search (num_beams == 1). """ Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly. All returned sequence are generated independantly.
@@ -930,7 +928,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
past = None past = None
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask) model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
outputs = self(**model_inputs) outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :] next_token_logits = outputs[0][:, -1, :]
@@ -948,16 +948,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
next_token_logits[ next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
batch_idx, banned_tokens[batch_idx]
] = -float('inf')
# set eos token prob to zero if min_length is not reached # set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length: if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids: for eos_token_id in eos_token_ids:
next_token_logits[ next_token_logits[:, eos_token_id] = -float("inf")
:, eos_token_id
] = -float('inf')
if do_sample: if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens) # Temperature (higher temperature => more likely to sample low probability tokens)
@@ -995,7 +991,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# extend attention_mask for new generated input # extend attention_mask for new generated input
if self.config.is_encoder_decoder is False: if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1) attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
)
cur_len = cur_len + 1 cur_len = cur_len + 1
@@ -1041,8 +1039,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# generated hypotheses # generated hypotheses
generated_hyps = [ generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size) BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size) for _ in range(batch_size)
] ]
# scores for each sentence in the beam # scores for each sentence in the beam
@@ -1060,7 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
done = [False for _ in range(batch_size)] done = [False for _ in range(batch_size)]
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask) model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@@ -1084,17 +1084,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# set eos token prob to zero if min_length is not reached # set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length: if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids: for eos_token_id in eos_token_ids:
scores[:, eos_token_id] = -float('inf') scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0: if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams # calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len) banned_batch_tokens = calc_banned_tokens(
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens): for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float('inf') scores[i, banned_tokens] = -float("inf")
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size)) assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)
if do_sample: if do_sample:
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
@@ -1203,7 +1207,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# extend attention_mask for new generated input # extend attention_mask for new generated input
if self.config.is_encoder_decoder is False: if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1) attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
)
# update current length # update current length
cur_len = cur_len + 1 cur_len = cur_len + 1
@@ -1278,7 +1284,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device=next(self.parameters()).device, device=next(self.parameters()).device,
) )
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]" assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float('inf') scores[:, all_but_token_ids_mask] = -float("inf")
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past, beam_idx):

View File

@@ -104,7 +104,9 @@ def prepare_bart_inputs_dict(
@require_torch @require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase): class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else () all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
# TODO(SS): fix the below in a separate PR # TODO(SS): fix the below in a separate PR
@@ -472,7 +474,7 @@ class BartModelIntegrationTest(unittest.TestCase):
min_length=min_length + 1, min_length=min_length + 1,
no_repeat_ngram_size=3, no_repeat_ngram_size=3,
do_sample=False, do_sample=False,
early_stopping=True early_stopping=True,
) )
decoded = [ decoded = [