Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
This commit is contained in:
Julien Chaumond
2020-06-02 09:39:33 -04:00
committed by GitHub
parent 47a551d17b
commit d4c2cb402d
115 changed files with 792 additions and 1323 deletions

View File

@@ -112,7 +112,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
@@ -122,7 +121,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
"""
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
@property
@@ -338,9 +336,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# Load model
if pretrained_model_name_or_path is not None:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -364,8 +360,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
use_cdn=use_cdn,
)
# redirect to the cache, if necessary
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
@@ -373,20 +369,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
resume_download=resume_download,
proxies=proxies,
)
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
", ".join(cls.pretrained_model_archive_map.keys()),
archive_file,
)
)
raise e
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else: