update loading logics
This commit is contained in:
@@ -163,7 +163,7 @@ if _tf_available and _torch_available:
|
|||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
cached_path, add_start_docstrings, add_end_docstrings,
|
cached_path, add_start_docstrings, add_end_docstrings,
|
||||||
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
||||||
|
|
||||||
def is_torch_available():
|
def is_torch_available():
|
||||||
return _torch_available
|
return _torch_available
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ except (AttributeError, ImportError):
|
|||||||
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||||
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
TF2_WEIGHTS_NAME = 'tf_model.h5'
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import os
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
|
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path is not None:
|
||||||
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
|
||||||
elif os.path.isdir(pretrained_model_name_or_path):
|
|
||||||
if from_pt:
|
|
||||||
# Load from a PyTorch checkpoint
|
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
|
||||||
else:
|
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
|
|
||||||
else:
|
|
||||||
archive_file = pretrained_model_name_or_path
|
|
||||||
# redirect to the cache, if necessary
|
|
||||||
try:
|
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
|
||||||
except EnvironmentError:
|
|
||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
logger.error(
|
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
||||||
"Couldn't reach server at '{}' to download pretrained weights.".format(
|
elif os.path.isdir(pretrained_model_name_or_path):
|
||||||
archive_file))
|
if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
|
# Load from a TF 2.0 checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||||
|
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
|
# Load from a PyTorch checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError("Error no file named {} found in directory {}".format(
|
||||||
|
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME),
|
||||||
|
pretrained_model_name_or_path))
|
||||||
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
|
archive_file = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
logger.error(
|
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
# redirect to the cache, if necessary
|
||||||
"associated to this path or url.".format(
|
try:
|
||||||
pretrained_model_name_or_path,
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
', '.join(cls.pretrained_model_archive_map.keys()),
|
except EnvironmentError as e:
|
||||||
archive_file))
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
return None
|
logger.error(
|
||||||
if resolved_archive_file == archive_file:
|
"Couldn't reach server at '{}' to download pretrained weights.".format(
|
||||||
logger.info("loading weights file {}".format(archive_file))
|
archive_file))
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
|
"associated to this path or url.".format(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
', '.join(cls.pretrained_model_archive_map.keys()),
|
||||||
|
archive_file))
|
||||||
|
raise e
|
||||||
|
if resolved_archive_file == archive_file:
|
||||||
|
logger.info("loading weights file {}".format(archive_file))
|
||||||
|
else:
|
||||||
|
logger.info("loading weights file {} from cache at {}".format(
|
||||||
|
archive_file, resolved_archive_file))
|
||||||
else:
|
else:
|
||||||
logger.info("loading weights file {} from cache at {}".format(
|
resolved_archive_file = None
|
||||||
archive_file, resolved_archive_file))
|
|
||||||
|
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
|
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module):
|
|||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
||||||
elif os.path.isdir(pretrained_model_name_or_path):
|
elif os.path.isdir(pretrained_model_name_or_path):
|
||||||
if from_tf:
|
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
|
||||||
# Directly load from a TensorFlow checkpoint
|
# Load from a TF 1.0 checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||||
else:
|
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
|
# Load from a TF 2.0 checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||||
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError("Error no file named {} found in directory {}".format(
|
||||||
|
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"),
|
||||||
|
pretrained_model_name_or_path))
|
||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user