[AutoModels] Fix config params handling of all PT and TF AutoModels (#5665)
* Fix AutoModelForCausalLM
* Leverage the existing `return_unused_kwargs` functionality
* Apply unused kwargs handling to all auto models
This commit is contained in:
committed by
GitHub
parent
8ab565a4be
commit
ec0a945cf9
@@ -498,7 +498,9 @@ class AutoModel:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -645,7 +647,9 @@ class AutoModelForPreTraining:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -802,7 +806,9 @@ class AutoModelWithLMHead:
|
||||
)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -937,7 +943,9 @@ class AutoModelForCausalLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice:
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
|
||||
@@ -450,7 +450,9 @@ class TFAutoModel(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object):
|
||||
config = kwargs.pop("config", None)
|
||||
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||
# Not using isinstance() here so that inheritance is not taken into account
|
||||
@@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
|
||||
Reference in New Issue
Block a user