Kill model archive maps (#4636)
* Kill model archive maps * Fixup * Also kill model_archive_map for MaskedBertPreTrainedModel * Unhook config_archive_map * Tokenizers: align with model id changes * make style && make quality * Fix CI
This commit is contained in:
@@ -63,8 +63,6 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@@ -411,7 +409,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
|
||||
@@ -57,7 +57,6 @@ class XxxConfig(PretrainedConfig):
|
||||
initializing all weight matrices.
|
||||
layer_norm_eps: The epsilon used by LayerNorm.
|
||||
"""
|
||||
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "xxx"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -32,13 +32,13 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
####################################################
|
||||
# This dict contrains shortcut names and associated url
|
||||
# for the pretrained weights provided with the models
|
||||
# This list contrains shortcut names for some of
|
||||
# the pretrained weights provided with the models
|
||||
####################################################
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"xxx-base-uncased": "https://cdn.huggingface.co/xxx-base-uncased-tf_model.h5",
|
||||
"xxx-large-uncased": "https://cdn.huggingface.co/xxx-large-uncased-tf_model.h5",
|
||||
}
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"xxx-base-uncased",
|
||||
"xxx-large-uncased",
|
||||
]
|
||||
|
||||
|
||||
####################################################
|
||||
@@ -180,7 +180,6 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
|
||||
@@ -34,13 +34,13 @@ from .modeling_utils import PreTrainedModel
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
####################################################
|
||||
# This dict contrains shortcut names and associated url
|
||||
# for the pretrained weights provided with the models
|
||||
# This list contrains shortcut names for some of
|
||||
# the pretrained weights provided with the models
|
||||
####################################################
|
||||
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"xxx-base-uncased": "https://cdn.huggingface.co/xxx-base-uncased-pytorch_model.bin",
|
||||
"xxx-large-uncased": "https://cdn.huggingface.co/xxx-large-uncased-pytorch_model.bin",
|
||||
}
|
||||
XXX_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"xxx-base-uncased",
|
||||
"xxx-large-uncased",
|
||||
]
|
||||
|
||||
|
||||
####################################################
|
||||
@@ -180,7 +180,6 @@ class XxxPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_xxx
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ if is_torch_available():
|
||||
XxxForSequenceClassification,
|
||||
XxxForTokenClassification,
|
||||
)
|
||||
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -269,6 +269,6 @@ class XxxModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in XXX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
Reference in New Issue
Block a user