Add TFBartForConditionalGeneration (#5411)
* half done * doc improvement * Cp test file * brokedn * broken test * undo some mess * ckpt * borked * Halfway * 6 passing * boom boom * Much progress but still 6 * boom boom * merged master * 10 passing * boom boom * Style * no t5 changes * 13 passing * Integration test failing, but not gibberish * Frustrated * Merged master * 4 fail * 4 fail * fix return_dict * boom boom * Still only 4 * prepare method * prepare method * before delete classif * Skip tests to avoid adding boilerplate * boom boom * fast tests passing * style * boom boom * Switch to supporting many input types * remove FIXMENORM * working * Fixed past_key_values/decoder_cached_states confusion * new broken test * Fix attention mask kwarg name * undo accidental * Style and reviewers * style * Docs and common tests * Cleaner assert messages * copy docs * style issues * Sphinx fix * Simplify caching logic * test does not require torch * copy _NoLayerEmbedTokens * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update tests/test_modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Line length and dont document None * Add pipeline test coverage * assert msg * At parity * Assert messages * mark slow * Update compile test * back in init * Merge master * Fix tests Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -640,6 +640,10 @@ class TFGenerationMixin:
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
# calculate log softmax score
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
@@ -890,6 +894,13 @@ class TFGenerationMixin:
|
||||
def _reorder_cache(past, beam_idx):
|
||||
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||
the generate method.
|
||||
"""
|
||||
return logits
|
||||
|
||||
|
||||
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
||||
# create logit penalties for already seen input_ids
|
||||
|
||||
Reference in New Issue
Block a user