TF generate refactor - Greedy Search (#15562)

* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape
_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvais suggestions

* Update tests/test_generation_tf_logits_process.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2022-02-15 17:54:43 +01:00
committed by GitHub
parent a3dbbc3467
commit 2e12b907ae
56 changed files with 1491 additions and 222 deletions

View File

@@ -955,7 +955,7 @@ class TFModelTesterMixin:
# Models with non-text inputs won't work here; num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
# generating multiple sequences when no beam search generation
# is not allowed as it would always generate the same sequences
model.generate(input_ids, do_sample=False, num_return_sequences=2)