best current version and make style
This commit is contained in:
committed by
Patrick von Platen
parent
c62444da39
commit
2acfe63964
@@ -19,8 +19,6 @@ import logging
|
||||
import os
|
||||
import typing
|
||||
|
||||
import ipdb
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
@@ -829,7 +827,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if num_return_sequences > 1 or num_beams > 1:
|
||||
input_ids_len = input_ids.shape[-1]
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||
attention_mask = attention_mask.unsqueeze(1).expand(
|
||||
batch_size, effective_batch_mult * num_beams, input_ids_len
|
||||
)
|
||||
|
||||
input_ids = input_ids.contiguous().view(
|
||||
effective_batch_size * num_beams, input_ids_len
|
||||
@@ -846,9 +846,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
encoder_inputs = input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
eos_token_id,
|
||||
# bos_token_id,
|
||||
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
|
||||
eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
@@ -919,7 +917,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
eos_token_ids,
|
||||
batch_size,
|
||||
encoder_inputs,
|
||||
attention_mask
|
||||
attention_mask,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequences are generated independently.
|
||||
@@ -930,7 +928,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
past = None
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
@@ -948,16 +948,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||
for batch_idx in range(batch_size):
|
||||
next_token_logits[
|
||||
batch_idx, banned_tokens[batch_idx]
|
||||
] = -float('inf')
|
||||
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[
|
||||
:, eos_token_id
|
||||
] = -float('inf')
|
||||
next_token_logits[:, eos_token_id] = -float("inf")
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -995,7 +991,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# extend attention_mask for new generated input
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||
)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
@@ -1041,8 +1039,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
||||
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# scores for each sentence in the beam
|
||||
@@ -1060,7 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
@@ -1084,17 +1084,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
scores[:, eos_token_id] = -float('inf')
|
||||
scores[:, eos_token_id] = -float("inf")
|
||||
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
num_batch_hypotheses = batch_size * num_beams
|
||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
|
||||
banned_batch_tokens = calc_banned_tokens(
|
||||
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
||||
)
|
||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||
scores[i, banned_tokens] = -float('inf')
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size))
|
||||
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
||||
scores.shape, (batch_size * num_beams, vocab_size)
|
||||
)
|
||||
|
||||
if do_sample:
|
||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
@@ -1203,7 +1207,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# extend attention_mask for new generated input
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||
)
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
@@ -1278,7 +1284,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
scores[:, all_but_token_ids_mask] = -float('inf')
|
||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
@@ -1311,7 +1317,7 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len)
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
start_idx = cur_len + 1 - no_repeat_ngram_size
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx: cur_len].tolist())
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
|
||||
Reference in New Issue
Block a user