From ec0a945cf94927051bb99748346dc7e0848af081 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Jul 2020 09:51:14 +0200 Subject: [PATCH] [AutoModels] Fix config params handling of all PT and TF AutoModels (#5665) * fix auto model causal lm * leverage given functionality * apply unused kwargs to all auto models --- src/transformers/modeling_auto.py | 40 +++++++++++++++++++++------- src/transformers/modeling_tf_auto.py | 40 +++++++++++++++++++++------- 2 files changed, 60 insertions(+), 20 deletions(-) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 7991254034..d59f8d6028 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -498,7 +498,9 @@ class AutoModel: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_MAPPING.items(): if isinstance(config, config_class): @@ -645,7 +647,9 @@ class AutoModelForPreTraining: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items(): if isinstance(config, config_class): @@ -802,7 +806,9 @@ class AutoModelWithLMHead: ) config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items(): if isinstance(config, config_class): @@ -937,7 +943,9 @@ class AutoModelForCausalLM: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items(): if isinstance(config, config_class): @@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items(): if isinstance(config, config_class): @@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items(): if isinstance(config, config_class): @@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items(): if isinstance(config, config_class): @@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items(): if isinstance(config, config_class): @@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items(): if isinstance(config, config_class): @@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice: def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items(): if isinstance(config, config_class): diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index bc954a7f4a..a47c0f30ad 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -450,7 +450,9 @@ class TFAutoModel(object): """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_MAPPING.items(): if isinstance(config, config_class): @@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object): """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items(): if isinstance(config, config_class): @@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object): config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items(): # Not using isinstance() here to do not take into account inheritance @@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items(): if isinstance(config, config_class): @@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items(): if isinstance(config, config_class): @@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items(): if isinstance(config, config_class): @@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items(): if isinstance(config, config_class): @@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object): """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items(): if isinstance(config, config_class): @@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object): """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items(): if isinstance(config, config_class): @@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification: """ config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items(): if isinstance(config, config_class):