Support T5 Generation (#3228)
* fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink strucuture a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that doesnt exists. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656e1386a2
commit
bbf26c4e61
@@ -806,10 +806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
num_return_sequences = (
|
||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||
)
|
||||
# TODO: think about how to make this cleaner
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
|
||||
)
|
||||
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
||||
@@ -912,20 +909,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
attention_mask = attention_mask.contiguous().view(
|
||||
effective_batch_size * num_beams, input_ids_len
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
||||
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
|
||||
encoder_inputs = input_ids
|
||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||
|
||||
# get encoder and store encoder outputs
|
||||
encoder = self.get_encoder()
|
||||
|
||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
||||
|
||||
# create empty decoder_input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
decoder_start_token_id, # TODO: see whether this is the best result
|
||||
decoder_start_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
cur_len = 1
|
||||
|
||||
else:
|
||||
encoder_inputs = None
|
||||
encoder_outputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
if num_beams > 1:
|
||||
@@ -944,12 +948,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
vocab_size=vocab_size,
|
||||
encoder_inputs=encoder_inputs,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
@@ -964,10 +969,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
encoder_inputs=encoder_inputs,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
@@ -985,10 +992,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
encoder_inputs,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
@@ -998,11 +1007,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
past = None
|
||||
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||
)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
@@ -1099,12 +1107,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_inputs,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
@@ -1125,15 +1134,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
# cache compute states
|
||||
past = None
|
||||
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
||||
|
||||
# done sentences
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
|
||||
)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
@@ -1152,8 +1159,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
# TODO: maybe give better naming
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solutino
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
@@ -1278,7 +1285,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
||||
|
||||
# re-order internal states
|
||||
if past:
|
||||
if past is not None:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# extend attention_mask for new generated input if only decoder
|
||||
@@ -1345,8 +1352,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return decoded[:, 1:]
|
||||
return decoded
|
||||
|
||||
# force one of token_ids to be generated by setting prob of all other tokens to 0.
|
||||
|
||||
Reference in New Issue
Block a user