Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
This commit is contained in:
Julien Chaumond
2020-06-02 09:39:33 -04:00
committed by GitHub
parent 47a551d17b
commit d4c2cb402d
115 changed files with 792 additions and 1323 deletions

View File

@@ -110,6 +110,9 @@ class ModuleUtilsMixin:
@property
def device(self) -> device:
"""
Get torch.device from module, assuming that the whole module has one device.
"""
try:
return next(self.parameters()).device
except StopIteration:
@@ -125,6 +128,9 @@ class ModuleUtilsMixin:
@property
def dtype(self) -> dtype:
"""
Get torch.dtype from module, assuming that the whole module has one dtype.
"""
try:
return next(self.parameters()).dtype
except StopIteration:
@@ -249,7 +255,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
@@ -259,7 +264,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
"""
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
@property
@@ -587,9 +591,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Load model
if pretrained_model_name_or_path is not None:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
# Load from a TF 1.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
@@ -622,8 +624,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
use_cdn=use_cdn,
)
# redirect to the cache, if necessary
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
@@ -632,20 +634,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
resume_download=resume_download,
local_files_only=local_files_only,
)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
else:
msg = (
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to model weight files named one of {} but "
"couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
", ".join(cls.pretrained_model_archive_map.keys()),
archive_file,
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
)
)
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file: