Kill model archive maps (#4636)
* Kill model archive maps * Fixup * Also kill model_archive_map for MaskedBertPreTrainedModel * Unhook config_archive_map * Tokenizers: align with model id changes * make style && make quality * Fix CI
This commit is contained in:
@@ -110,6 +110,9 @@ class ModuleUtilsMixin:
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
"""
|
||||
Get torch.device from module, assuming that the whole module has one device.
|
||||
"""
|
||||
try:
|
||||
return next(self.parameters()).device
|
||||
except StopIteration:
|
||||
@@ -125,6 +128,9 @@ class ModuleUtilsMixin:
|
||||
|
||||
@property
|
||||
def dtype(self) -> dtype:
|
||||
"""
|
||||
Get torch.dtype from module, assuming that the whole module has one dtype.
|
||||
"""
|
||||
try:
|
||||
return next(self.parameters()).dtype
|
||||
except StopIteration:
|
||||
@@ -249,7 +255,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
|
||||
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
|
||||
|
||||
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
|
||||
@@ -259,7 +264,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
|
||||
"""
|
||||
config_class = None
|
||||
pretrained_model_archive_map = {}
|
||||
base_model_prefix = ""
|
||||
|
||||
@property
|
||||
@@ -587,9 +591,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
|
||||
# Load from a TF 1.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||
@@ -622,8 +624,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
use_cdn=use_cdn,
|
||||
)
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
@@ -632,20 +634,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if resolved_archive_file is None:
|
||||
raise EnvironmentError
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
|
||||
else:
|
||||
msg = (
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url to model weight files named one of {} but "
|
||||
"couldn't find any such file at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
", ".join(cls.pretrained_model_archive_map.keys()),
|
||||
archive_file,
|
||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
|
||||
)
|
||||
)
|
||||
msg = (
|
||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
|
||||
Reference in New Issue
Block a user