TF generate refactor - Sample (#15793)

* Add TF logits wrappers 

* Add sample method

* add tests for TF logit wrappers

* TF generate sample tests now run on CPU

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-03-02 16:13:54 +00:00
committed by GitHub
parent 96ae92be8c
commit baab5e7cdf
13 changed files with 652 additions and 332 deletions

View File

@@ -947,7 +947,7 @@ class TFModelTesterMixin:
if config.bos_token_id is None:
# if bos token id is not defined model needs input_ids
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True))