best current version and make style
This commit is contained in:
committed by
Patrick von Platen
parent
c62444da39
commit
2acfe63964
@@ -79,8 +79,7 @@ class T5Config(PretrainedConfig):
|
|||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder, **kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_positions = n_positions
|
self.n_positions = n_positions
|
||||||
|
|||||||
@@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
|
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
|
||||||
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
|
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
|
||||||
|
attention_mask.shape, encoder_inputs.shape
|
||||||
|
)
|
||||||
if past is None: # first step
|
if past is None: # first step
|
||||||
encoder_outputs, decoder_cached_states = None, None
|
encoder_outputs, decoder_cached_states = None, None
|
||||||
else:
|
else:
|
||||||
@@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"decoder_cached_states": decoder_cached_states,
|
"decoder_cached_states": decoder_cached_states,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
@@ -829,7 +827,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
if num_return_sequences > 1 or num_beams > 1:
|
if num_return_sequences > 1 or num_beams > 1:
|
||||||
input_ids_len = input_ids.shape[-1]
|
input_ids_len = input_ids.shape[-1]
|
||||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||||
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
attention_mask = attention_mask.unsqueeze(1).expand(
|
||||||
|
batch_size, effective_batch_mult * num_beams, input_ids_len
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = input_ids.contiguous().view(
|
input_ids = input_ids.contiguous().view(
|
||||||
effective_batch_size * num_beams, input_ids_len
|
effective_batch_size * num_beams, input_ids_len
|
||||||
@@ -846,9 +846,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
encoder_inputs = input_ids
|
encoder_inputs = input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
eos_token_id,
|
eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
|
||||||
# bos_token_id,
|
|
||||||
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
|
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
@@ -919,7 +917,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
encoder_inputs,
|
encoder_inputs,
|
||||||
attention_mask
|
attention_mask,
|
||||||
):
|
):
|
||||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||||
All returned sequences are generated independently.
|
All returned sequences are generated independently.
|
||||||
@@ -930,7 +928,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
past = None
|
past = None
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
model_inputs = self.prepare_inputs_for_generation(
|
||||||
|
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
@@ -948,16 +948,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
next_token_logits[
|
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
||||||
batch_idx, banned_tokens[batch_idx]
|
|
||||||
] = -float('inf')
|
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
for eos_token_id in eos_token_ids:
|
||||||
next_token_logits[
|
next_token_logits[:, eos_token_id] = -float("inf")
|
||||||
:, eos_token_id
|
|
||||||
] = -float('inf')
|
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
@@ -995,7 +991,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# extend attention_mask for new generated input
|
# extend attention_mask for new generated input
|
||||||
if self.config.is_encoder_decoder is False:
|
if self.config.is_encoder_decoder is False:
|
||||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
attention_mask = torch.cat(
|
||||||
|
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
@@ -1041,8 +1039,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# generated hypotheses
|
# generated hypotheses
|
||||||
generated_hyps = [
|
generated_hyps = [
|
||||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
|
||||||
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
for _ in range(batch_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
# scores for each sentence in the beam
|
# scores for each sentence in the beam
|
||||||
@@ -1060,7 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
model_inputs = self.prepare_inputs_for_generation(
|
||||||
|
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||||
|
)
|
||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -1084,17 +1084,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
for eos_token_id in eos_token_ids:
|
||||||
scores[:, eos_token_id] = -float('inf')
|
scores[:, eos_token_id] = -float("inf")
|
||||||
|
|
||||||
if no_repeat_ngram_size > 0:
|
if no_repeat_ngram_size > 0:
|
||||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
num_batch_hypotheses = batch_size * num_beams
|
num_batch_hypotheses = batch_size * num_beams
|
||||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
|
banned_batch_tokens = calc_banned_tokens(
|
||||||
|
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
||||||
|
)
|
||||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||||
scores[i, banned_tokens] = -float('inf')
|
scores[i, banned_tokens] = -float("inf")
|
||||||
|
|
||||||
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size))
|
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
||||||
|
scores.shape, (batch_size * num_beams, vocab_size)
|
||||||
|
)
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -1203,7 +1207,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# extend attention_mask for new generated input
|
# extend attention_mask for new generated input
|
||||||
if self.config.is_encoder_decoder is False:
|
if self.config.is_encoder_decoder is False:
|
||||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
attention_mask = torch.cat(
|
||||||
|
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
# update current length
|
# update current length
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
@@ -1278,7 +1284,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||||
scores[:, all_but_token_ids_mask] = -float('inf')
|
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
@@ -1311,7 +1317,7 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len)
|
|||||||
def _get_generated_ngrams(hypo_idx):
|
def _get_generated_ngrams(hypo_idx):
|
||||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||||
start_idx = cur_len + 1 - no_repeat_ngram_size
|
start_idx = cur_len + 1 - no_repeat_ngram_size
|
||||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx: cur_len].tolist())
|
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
||||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||||
|
|
||||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||||
|
|||||||
@@ -104,7 +104,9 @@ def prepare_bart_inputs_dict(
|
|||||||
@require_torch
|
@require_torch
|
||||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||||
|
)
|
||||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
# TODO(SS): fix the below in a separate PR
|
# TODO(SS): fix the below in a separate PR
|
||||||
@@ -451,9 +453,9 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
|
EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
|
||||||
|
|
||||||
dct = tok.batch_encode_plus(
|
dct = tok.batch_encode_plus(
|
||||||
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
|
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
|
||||||
[IRAN_ARTICLE, ARTICLE_SUBWAY],
|
[IRAN_ARTICLE, ARTICLE_SUBWAY],
|
||||||
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
|
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
|
||||||
max_length=1024,
|
max_length=1024,
|
||||||
pad_to_max_length=True,
|
pad_to_max_length=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -472,17 +474,17 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
min_length=min_length + 1,
|
min_length=min_length + 1,
|
||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
early_stopping=True
|
early_stopping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
decoded = [
|
decoded = [
|
||||||
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
|
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
|
||||||
]
|
]
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||||
[EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
[EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||||
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
|
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
|
||||||
decoded,
|
decoded,
|
||||||
)
|
)
|
||||||
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
||||||
|
|||||||
Reference in New Issue
Block a user