TF: BART compatible with XLA generation (#17479)
* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
@@ -295,7 +295,7 @@ class TFGPT2ModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
|
||||
config.eos_token_id = None
|
||||
config.eos_token_id = None # Generate until max length
|
||||
config.max_length = 10
|
||||
model = TFGPT2LMHeadModel(config=config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user