Support T5 Generation (#3228)

* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink structure a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoders that don't exist.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Patrick von Platen
2020-03-19 23:18:23 +01:00
committed by GitHub
parent 656e1386a2
commit bbf26c4e61
16 changed files with 449 additions and 280 deletions

View File

@@ -806,10 +806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
# TODO: think about how to make this cleaner
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
if input_ids is not None:
batch_size = input_ids.shape[0] # overriden by the input batch_size
@@ -912,20 +909,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
encoder_inputs = input_ids
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# create empty decoder_input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
decoder_start_token_id, # TODO: see whether this is the best result
decoder_start_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
cur_len = 1
else:
encoder_inputs = None
encoder_outputs = None
cur_len = input_ids.shape[-1]
if num_beams > 1:
@@ -944,12 +948,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_inputs=encoder_inputs,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
else:
@@ -964,10 +969,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
encoder_inputs=encoder_inputs,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
@@ -985,10 +992,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
encoder_inputs,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example without beam search (num_beams == 1).
@@ -998,11 +1007,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@@ -1099,12 +1107,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_inputs,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example with beam search.
@@ -1125,15 +1134,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = None
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@@ -1152,8 +1159,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder and do_sample is False:
# TODO: maybe give better naming
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# TODO (PVP) - still a bit hacky here - there might be a better solution
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
@@ -1278,7 +1285,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
# re-order internal states
if past:
if past is not None:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input if only decoder
@@ -1345,8 +1352,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if self.config.is_encoder_decoder:
return decoded[:, 1:]
return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0.