Support T5 Generation (#3228)

* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink strucuture a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoders that doesnt exists.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Patrick von Platen
2020-03-19 23:18:23 +01:00
committed by GitHub
parent 656e1386a2
commit bbf26c4e61
16 changed files with 449 additions and 280 deletions

View File

@@ -474,6 +474,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
@@ -586,7 +587,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`, `TFXLMWithLMHeadModel`)"
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
)
max_length = max_length if max_length is not None else self.config.max_length
@@ -608,6 +609,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
if input_ids is not None:
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
@@ -634,6 +636,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert length_penalty > 0, "`length_penalty` should be strictely positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
@@ -703,6 +708,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
attention_mask, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# create empty decoder_input_ids
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
cur_len = 1
else:
encoder_outputs = None
cur_len = shape_list(input_ids)[-1]
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
@@ -716,13 +740,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
else:
@@ -737,10 +764,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
@@ -758,10 +788,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
vocab_size,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example without beam search (num_beams == 1).
@@ -772,7 +805,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
unfinished_sents = tf.ones_like(input_ids[:, 0])
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
past = None
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
@@ -859,6 +892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if tf.math.reduce_max(unfinished_sents) == 0:
break
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
cur_len = cur_len + 1
# if there are different sentences lengths in the batch, some batches have to be padded
@@ -896,13 +935,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example with beam search.
@@ -923,8 +965,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
# cache compute states
past = None
past = encoder_outputs
# done sentences
done = [False for _ in range(batch_size)]
@@ -1088,9 +1131,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
# re-order internal states
if past:
if past is not None:
past = self._reorder_cache(past, beam_idx)
if self.config.is_encoder_decoder is False:
attention_mask = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
# update current length
cur_len = cur_len + 1