remap MODEL_FOR_QUESTION_ANSWERING_MAPPING classes to names auto-generated file (#10487)

* remap classes to strings

* missing new util

* style

* doc

* move the autogenerated file

* Trigger CI
This commit is contained in:
Stas Bekman
2021-03-03 08:54:00 -08:00
committed by GitHub
parent 801ff969ce
commit 188574ac50
4 changed files with 97 additions and 2 deletions

View File

@@ -61,7 +61,6 @@ from .file_utils import (
is_torch_tpu_available,
)
from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
@@ -104,6 +103,7 @@ from .trainer_utils import (
)
from .training_args import ParallelMode, TrainingArguments
from .utils import logging
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
_is_native_amp_available = False
@@ -420,7 +420,7 @@ class Trainer:
self.use_tune_checkpoints = False
default_label_names = (
["start_positions", "end_positions"]
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
else ["labels"]
)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names

View File

@@ -0,0 +1,34 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("ConvBertConfig", "ConvBertForQuestionAnswering"),
("LEDConfig", "LEDForQuestionAnswering"),
("DistilBertConfig", "DistilBertForQuestionAnswering"),
("AlbertConfig", "AlbertForQuestionAnswering"),
("CamembertConfig", "CamembertForQuestionAnswering"),
("BartConfig", "BartForQuestionAnswering"),
("MBartConfig", "MBartForQuestionAnswering"),
("LongformerConfig", "LongformerForQuestionAnswering"),
("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
("RobertaConfig", "RobertaForQuestionAnswering"),
("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
("BertConfig", "BertForQuestionAnswering"),
("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
("MobileBertConfig", "MobileBertForQuestionAnswering"),
("XLMConfig", "XLMForQuestionAnsweringSimple"),
("ElectraConfig", "ElectraForQuestionAnswering"),
("ReformerConfig", "ReformerForQuestionAnswering"),
("FunnelConfig", "FunnelForQuestionAnswering"),
("LxmertConfig", "LxmertForQuestionAnswering"),
("MPNetConfig", "MPNetForQuestionAnswering"),
("DebertaConfig", "DebertaForQuestionAnswering"),
("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
("IBertConfig", "IBertForQuestionAnswering"),
]
)