best current version and make style
This commit is contained in:
committed by
Patrick von Platen
parent
c62444da39
commit
2acfe63964
@@ -79,8 +79,7 @@ class T5Config(PretrainedConfig):
|
|||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder, **kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_positions = n_positions
|
self.n_positions = n_positions
|
||||||
|
|||||||
@@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
|
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
|
||||||
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
|
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
|
||||||
|
attention_mask.shape, encoder_inputs.shape
|
||||||
|
)
|
||||||
if past is None: # first step
|
if past is None: # first step
|
||||||
encoder_outputs, decoder_cached_states = None, None
|
encoder_outputs, decoder_cached_states = None, None
|
||||||
else:
|
else:
|
||||||
@@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"decoder_cached_states": decoder_cached_states,
|
"decoder_cached_states": decoder_cached_states,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
@@ -829,7 +827,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
if num_return_sequences > 1 or num_beams > 1:
|
if num_return_sequences > 1 or num_beams > 1:
|
||||||
input_ids_len = input_ids.shape[-1]
|
input_ids_len = input_ids.shape[-1]
|
||||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||||
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
attention_mask = attention_mask.unsqueeze(1).expand(
|
||||||
|
batch_size, effective_batch_mult * num_beams, input_ids_len
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = input_ids.contiguous().view(
|
input_ids = input_ids.contiguous().view(
|
||||||
effective_batch_size * num_beams, input_ids_len
|
effective_batch_size * num_beams, input_ids_len
|
||||||
@@ -846,9 +846,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
encoder_inputs = input_ids
|
encoder_inputs = input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
eos_token_id,
|
eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
|
||||||
# bos_token_id,
|
|
||||||
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
|
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
@@ -919,7 +917,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
encoder_inputs,
|
encoder_inputs,
|
||||||
attention_mask
|
attention_mask,
|
||||||
):
|
):
|
||||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||||
All returned sequence are generated independantly.
|
All returned sequence are generated independantly.
|
||||||
@@ -930,7 +928,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
past = None
|
past = None
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
model_inputs = self.prepare_inputs_for_generation(
|
||||||
|
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
@@ -948,16 +948,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
next_token_logits[
|
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
||||||
batch_idx, banned_tokens[batch_idx]
|
|
||||||
] = -float('inf')
|
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
for eos_token_id in eos_token_ids:
|
||||||
next_token_logits[
|
next_token_logits[:, eos_token_id] = -float("inf")
|
||||||
:, eos_token_id
|
|
||||||
] = -float('inf')
|
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
@@ -995,7 +991,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# extend attention_mask for new generated input
|
# extend attention_mask for new generated input
|
||||||
if self.config.is_encoder_decoder is False:
|
if self.config.is_encoder_decoder is False:
|
||||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
attention_mask = torch.cat(
|
||||||
|
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
@@ -1041,8 +1039,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# generated hypotheses
|
# generated hypotheses
|
||||||
generated_hyps = [
|
generated_hyps = [
|
||||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
|
||||||
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
for _ in range(batch_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
# scores for each sentence in the beam
|
# scores for each sentence in the beam
|
||||||
@@ -1060,7 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
model_inputs = self.prepare_inputs_for_generation(
|
||||||
|
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||||
|
)
|
||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -1084,17 +1084,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
for eos_token_id in eos_token_ids:
|
||||||
scores[:, eos_token_id] = -float('inf')
|
scores[:, eos_token_id] = -float("inf")
|
||||||
|
|
||||||
if no_repeat_ngram_size > 0:
|
if no_repeat_ngram_size > 0:
|
||||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
num_batch_hypotheses = batch_size * num_beams
|
num_batch_hypotheses = batch_size * num_beams
|
||||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
|
banned_batch_tokens = calc_banned_tokens(
|
||||||
|
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
||||||
|
)
|
||||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||||
scores[i, banned_tokens] = -float('inf')
|
scores[i, banned_tokens] = -float("inf")
|
||||||
|
|
||||||
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size))
|
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
||||||
|
scores.shape, (batch_size * num_beams, vocab_size)
|
||||||
|
)
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -1203,7 +1207,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# extend attention_mask for new generated input
|
# extend attention_mask for new generated input
|
||||||
if self.config.is_encoder_decoder is False:
|
if self.config.is_encoder_decoder is False:
|
||||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
attention_mask = torch.cat(
|
||||||
|
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
# update current length
|
# update current length
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
@@ -1278,7 +1284,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||||
scores[:, all_but_token_ids_mask] = -float('inf')
|
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
|
|||||||
@@ -104,7 +104,9 @@ def prepare_bart_inputs_dict(
|
|||||||
@require_torch
|
@require_torch
|
||||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||||
|
)
|
||||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
# TODO(SS): fix the below in a separate PR
|
# TODO(SS): fix the below in a separate PR
|
||||||
@@ -472,7 +474,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
min_length=min_length + 1,
|
min_length=min_length + 1,
|
||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
early_stopping=True
|
early_stopping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
decoded = [
|
decoded = [
|
||||||
|
|||||||
Reference in New Issue
Block a user