Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes * Use get_values everywhere * Prettier doc
This commit is contained in:
@@ -19,6 +19,8 @@ import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.models.auto import get_values
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_repo.py
|
||||
@@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"DPRReader",
|
||||
"DPRSpanPredictor",
|
||||
"FlaubertForQuestionAnswering",
|
||||
"FunnelBaseModel",
|
||||
"GPT2DoubleHeadsModel",
|
||||
"OpenAIGPTDoubleHeadsModel",
|
||||
"RagModel",
|
||||
@@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"T5Stack",
|
||||
"TFDPRReader",
|
||||
"TFDPRSpanPredictor",
|
||||
"TFFunnelBaseModel",
|
||||
"TFGPT2DoubleHeadsModel",
|
||||
"TFOpenAIGPTDoubleHeadsModel",
|
||||
"TFRagModel",
|
||||
@@ -153,7 +153,7 @@ def get_model_modules():
|
||||
def get_models(module):
|
||||
""" Get the objects in module that are models."""
|
||||
models = []
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel)
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||
for attr_name in dir(module):
|
||||
if "Pretrained" in attr_name or "PreTrained" in attr_name:
|
||||
continue
|
||||
@@ -249,10 +249,13 @@ def get_all_auto_configured_models():
|
||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
||||
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values())
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
||||
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
||||
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
|
||||
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values())
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
||||
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
||||
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
||||
return [cls.__name__ for cls in result]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user