Kill model archive maps (#4636)
* Kill model archive maps * Fixup * Also kill model_archive_map for MaskedBertPreTrainedModel * Unhook config_archive_map * Tokenizers: align with model id changes * make style && make quality * Fix CI
This commit is contained in:
@@ -19,7 +19,6 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
|
||||
|
||||
import logging
|
||||
|
||||
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@@ -31,7 +30,6 @@ class MaskedBertConfig(PretrainedConfig):
|
||||
A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "masked_bert"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -29,12 +29,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from emmental import MaskedBertConfig
|
||||
from emmental.modules import MaskedLinear
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from transformers.modeling_bert import (
|
||||
ACT2FN,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BertLayerNorm,
|
||||
load_tf_weights_in_bert,
|
||||
)
|
||||
from transformers.modeling_bert import ACT2FN, BertLayerNorm, load_tf_weights_in_bert
|
||||
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
|
||||
|
||||
|
||||
@@ -395,7 +390,6 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = MaskedBertConfig
|
||||
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_bert
|
||||
base_model_prefix = "bert"
|
||||
|
||||
|
||||
@@ -53,8 +53,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
|
||||
@@ -576,7 +574,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
|
||||
@@ -57,8 +57,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"masked_bert": (MaskedBertConfig, MaskedBertForQuestionAnswering, BertTokenizer),
|
||||
@@ -673,7 +671,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
|
||||
Reference in New Issue
Block a user