TF generate refactor - Greedy Search (#15562)
* TF generate start refactor * Add tf tests for sample generate * re-organize * boom boom * Apply suggestions from code review * re-add * add all code * make random greedy pass * make encoder-decoder random work * further improvements * delete bogus file * make gpt2 and t5 tests work * finish logits tests * correct logits processors * correct past / encoder_outputs drama * refactor some methods * another fix * refactor shape_list * fix more shape list * import shape _list * finish docs * fix imports * make style * correct tf utils * Fix TFRag as well * Apply Lysandre's and Sylvais suggestions * Update tests/test_generation_tf_logits_process.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Update src/transformers/tf_utils.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * remove cpu according to gante * correct logit processor Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a3dbbc3467
commit
2e12b907ae
@@ -955,7 +955,7 @@ class TFModelTesterMixin:
|
||||
# Models with non-text inputs won't work here; num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
||||
|
||||
Reference in New Issue
Block a user