TF: BART compatible with XLA generation (#17479)

* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
Joao Gante
2022-06-20 11:07:46 +01:00
committed by GitHub
parent 6589e510fa
commit 132402d752
18 changed files with 421 additions and 86 deletions

View File

@@ -295,7 +295,7 @@ class TFGPT2ModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.eos_token_id = None # Generate until max length
config.max_length = 10
model = TFGPT2LMHeadModel(config=config)