TF generate refactor - Greedy Search (#15562)

* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape
_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvais suggestions

* Update tests/test_generation_tf_logits_process.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2022-02-15 17:54:43 +01:00
committed by GitHub
parent a3dbbc3467
commit 2e12b907ae
56 changed files with 1491 additions and 222 deletions

View File

@@ -53,8 +53,8 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -1803,7 +1803,7 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings,
input_processing,
keras_serializable,
shape_list,
); from ...tf_utils import (shape_list,
)
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config