Add TFBartForConditionalGeneration (#5411)

* half done

* doc improvement

* Cp test file

* broken

* broken test

* undo some mess

* ckpt

* borked

* Halfway

* 6 passing

* boom boom

* Much progress but still 6

* boom boom

* merged master

* 10 passing

* boom boom

* Style

* no t5 changes

* 13 passing

* Integration test failing, but not gibberish

* Frustrated

* Merged master

* 4 fail

* 4 fail

* fix return_dict

* boom boom

* Still only 4

* prepare method

* prepare method

* before delete classif

* Skip tests to avoid adding boilerplate

* boom boom

* fast tests passing

* style

* boom boom

* Switch to supporting many input types

* remove FIXMENORM

* working

* Fixed past_key_values/decoder_cached_states confusion

* new broken test

* Fix attention mask kwarg name

* undo accidental

* Style and reviewers

* style

* Docs and common tests

* Cleaner assert messages

* copy docs

* style issues

* Sphinx fix

* Simplify caching logic

* test does not require torch

* copy _NoLayerEmbedTokens

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update tests/test_modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Line length and dont document None

* Add pipeline test coverage

* assert msg

* At parity

* Assert messages

* mark slow

* Update compile test

* back in init

* Merge master

* Fix tests

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Sam Shleifer
2020-10-21 07:10:16 -04:00
committed by GitHub
parent 5cd9e2cba1
commit 829842159e
20 changed files with 1731 additions and 116 deletions

@@ -640,6 +640,10 @@ class TFGenerationMixin:
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            if self.config.is_encoder_decoder and do_sample is False:
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )
            # calculate log softmax score
            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
@@ -890,6 +894,13 @@ class TFGenerationMixin:
    def _reorder_cache(past, beam_idx):
        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)

    def adjust_logits_during_generation(self, logits, **kwargs):
        """
        Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
        the generate method.
        """
        return logits


def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
    # create logit penalties for already seen input_ids
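
The two hunks above fit together as follows: during non-sampling generation for an encoder-decoder model, the logits pass through adjust_logits_during_generation before the log-softmax, and the default hook returns them unchanged. The sketch below illustrates that pattern in plain Python rather than TensorFlow so it runs without dependencies. ToyGenerationMixin, score_step, ForceEosAtEnd and the eos_token_id value are hypothetical names invented for illustration. They are not part of the library, and the override shown is only one plausible use of the hook, not the behavior this commit implements.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a plain list of floats."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


class ToyGenerationMixin:
    """Hypothetical stand-in for the mixin; models only the hook and its call site."""

    is_encoder_decoder = True  # assumption: plays the role of config.is_encoder_decoder

    def adjust_logits_during_generation(self, logits, **kwargs):
        # Default hook: return the logits unchanged, as in the diff.
        return logits

    def score_step(self, next_token_logits, cur_len, max_length, temperature=1.0, do_sample=False):
        # Same order as the first hunk: temperature, then the hook, then log-softmax.
        if temperature != 1.0:
            next_token_logits = [x / temperature for x in next_token_logits]
        if self.is_encoder_decoder and do_sample is False:
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )
        return log_softmax(next_token_logits)


class ForceEosAtEnd(ToyGenerationMixin):
    """Hypothetical subclass: on the final step, mask every token except EOS."""

    eos_token_id = 2  # assumption: illustrative id only

    def adjust_logits_during_generation(self, logits, cur_len, max_length, **kwargs):
        if cur_len == max_length - 1:
            return [x if i == self.eos_token_id else -math.inf for i, x in enumerate(logits)]
        return logits


model = ForceEosAtEnd()
final_scores = model.score_step([0.5, 1.5, -0.3, 0.1], cur_len=4, max_length=5)
early_scores = model.score_step([0.5, 1.5, -0.3, 0.1], cur_len=1, max_length=5)
```

On the final step only the EOS position keeps a finite score, so greedy or beam search must pick it. On earlier steps the override falls through and the scores match those of the base class.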