Model templates encoder only (#8509)
* Model templates * TensorFlow * Remove pooler * CI * Tokenizer + Refactoring * Encoder-Decoder * Let's go testing * Encoder-Decoder in TF * Let's go testing in TF * Documentation * README * Fixes * Better names * Style * Update docs * Choose to skip either TF or PT * Code quality fixes * Add to testing suite * Update file path * Cookiecutter path * Update `transformers` path * Handle rebasing * Remove seq2seq from model templates * Remove s2s config * Apply Sylvain and Patrick comments * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Last fixes from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -279,20 +279,9 @@ def check_models_are_documented(module, doc_file):
|
||||
def _get_model_name(module):
|
||||
""" Get the model name for the module defining it."""
|
||||
splits = module.__name__.split("_")
|
||||
splits = splits[(2 if splits[1] in ["flax", "tf"] else 1) :]
|
||||
|
||||
# Secial case for transfo_xl
|
||||
if splits[-1] == "xl":
|
||||
return "_".join(splits[-2:])
|
||||
# Special case for xlm_prophetnet
|
||||
if splits[-1] == "prophetnet" and splits[-2] == "xlm":
|
||||
return "_".join(splits[-2:])
|
||||
# Secial case for xlm_roberta
|
||||
if splits[-1] == "roberta" and splits[-2] == "xlm":
|
||||
return "_".join(splits[-2:])
|
||||
# Special case for bert_generation
|
||||
if splits[-1] == "generation" and splits[-2] == "bert":
|
||||
return "_".join(splits[-2:])
|
||||
return splits[-1]
|
||||
return "_".join(splits)
|
||||
|
||||
|
||||
def check_all_models_are_documented():
|
||||
|
||||
Reference in New Issue
Block a user