Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
This commit is contained in:
Sylvain Gugger
2021-04-08 18:41:36 -04:00
committed by GitHub
parent 97ccf67bb3
commit ba8b1f4754
26 changed files with 188 additions and 72 deletions

View File

@@ -19,6 +19,8 @@ import os
import re
from pathlib import Path
from transformers.models.auto import get_values
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_repo.py
@@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"RagModel",
@@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"T5Stack",
"TFDPRReader",
"TFDPRSpanPredictor",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
"TFRagModel",
@@ -153,7 +153,7 @@ def get_model_modules():
def get_models(module):
""" Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel)
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name:
continue
@@ -249,10 +249,13 @@ def get_all_auto_configured_models():
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values())
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values())
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
return [cls.__name__ for cls in result]