TFMarian, TFMbart, TFPegasus, TFBlenderbot (#7987)

* Start plumbing

* Marian close

* Small stubs for all children

* Fixed bart

* marian working

* pegasus test is good, but failing

* Checkin tests

* More model files

* Subtle marian, pegasus integration test failures

* Works well

* rm print

* boom boom

* Still failing model2doc

* merge master

* Equivalence test failing, all others fixed

* cleanup

* Fix embed_scale

* Cleanup marian pipeline test

* Undo extra changes

* Smaller delta

* Cleanup model testers

* undo delta

* fix tests import structure

* cross test decorator

* Cleaner set_weights

* Respect authorized_unexpected_keys

* No warnings

* No warnings

* style

* Nest tf import

* black

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* functional dropout

* fixup

* Fixup

* style_doc

* embs

* shape list

* delete slow force_token_id_to_be_generated func

* fixup

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sam Shleifer
2020-10-30 11:23:16 -04:00
committed by GitHub
parent 6279072f5f
commit 566b083eb1
20 changed files with 1063 additions and 106 deletions

View File

@@ -67,6 +67,7 @@ MODEL_NAME_TO_DOC_FILE = {
"xlm_prophetnet": "xlmprophetnet.rst",
"xlm_roberta": "xlmroberta.rst",
"bert_generation": "bertgeneration.rst",
"marian": "marian.rst",
}
# This is to make sure the transformers module imported is the one in the repo.
@@ -148,7 +149,6 @@ def get_model_doc_files():
_ignore_modules = [
"auto",
"dialogpt",
"marian",
"retribert",
]
doc_files = []
@@ -245,6 +245,7 @@ def check_models_are_documented(module, doc_file):
def _get_model_name(module):
""" Get the model name for the module defining it."""
splits = module.__name__.split("_")
# Secial case for transfo_xl
if splits[-1] == "xl":
return "_".join(splits[-2:])