Support T5 Generation (#3228)
* fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink strucuture a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that doesnt exists. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656e1386a2
commit
bbf26c4e61
@@ -474,6 +474,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
no_repeat_ngram_size=None,
|
||||
num_return_sequences=None,
|
||||
attention_mask=None,
|
||||
decoder_start_token_id=None,
|
||||
):
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
and beam-search.
|
||||
@@ -586,7 +587,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if self.get_output_embeddings() is None:
|
||||
raise AttributeError(
|
||||
"You tried to generate sequences with a model that does not have a LM Head."
|
||||
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`, `TFXLMWithLMHeadModel`)"
|
||||
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
|
||||
)
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
@@ -608,6 +609,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
num_return_sequences = (
|
||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||
)
|
||||
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
||||
@@ -634,6 +636,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
assert (eos_token_ids is None) or (
|
||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||
assert (
|
||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||
assert length_penalty > 0, "`length_penalty` should be strictely positive."
|
||||
assert (
|
||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||
@@ -703,6 +708,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
attention_mask, (effective_batch_size * num_beams, input_ids_len)
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
|
||||
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||
|
||||
# get encoder and store encoder outputs
|
||||
encoder = self.get_encoder()
|
||||
|
||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
||||
|
||||
# create empty decoder_input_ids
|
||||
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
|
||||
cur_len = 1
|
||||
|
||||
else:
|
||||
encoder_outputs = None
|
||||
cur_len = shape_list(input_ids)[-1]
|
||||
|
||||
if num_beams > 1:
|
||||
output = self._generate_beam_search(
|
||||
input_ids,
|
||||
@@ -716,13 +740,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
@@ -737,10 +764,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
@@ -758,10 +788,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
@@ -772,7 +805,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
unfinished_sents = tf.ones_like(input_ids[:, 0])
|
||||
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
|
||||
|
||||
past = None
|
||||
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
@@ -859,6 +892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if tf.math.reduce_max(unfinished_sents) == 0:
|
||||
break
|
||||
|
||||
# extend attention_mask for new generated input if only decoder
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = tf.concat(
|
||||
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
||||
)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# if there are different sentences lengths in the batch, some batches have to be padded
|
||||
@@ -896,13 +935,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
@@ -923,8 +965,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
|
||||
|
||||
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
|
||||
|
||||
# cache compute states
|
||||
past = None
|
||||
past = encoder_outputs
|
||||
|
||||
# done sentences
|
||||
done = [False for _ in range(batch_size)]
|
||||
@@ -1088,9 +1131,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
|
||||
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
|
||||
# re-order internal states
|
||||
if past:
|
||||
if past is not None:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = tf.concat(
|
||||
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
||||
)
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user