From f5397ffc3bf444e814b4234526dccba146be0347 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 24 Sep 2019 14:03:58 +0200 Subject: [PATCH] update loading logics --- pytorch_transformers/__init__.py | 2 +- pytorch_transformers/file_utils.py | 1 + pytorch_transformers/modeling_tf_utils.py | 71 +++++++++++++---------- pytorch_transformers/modeling_utils.py | 16 +++-- 4 files changed, 55 insertions(+), 35 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index b8c7eccfe7..508d0f84c4 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -163,7 +163,7 @@ if _tf_available and _torch_available: # Files and general utilities from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings, - WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) + WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) def is_torch_available(): return _torch_available diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index a656e757b5..34333aaafb 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -49,6 +49,7 @@ except (AttributeError, ImportError): PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility WEIGHTS_NAME = "pytorch_model.bin" +TF2_WEIGHTS_NAME = 'tf_model.h5' TF_WEIGHTS_NAME = 'model.ckpt' CONFIG_NAME = "config.json" diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py index f2b3623a3c..4a695e2864 100644 --- a/pytorch_transformers/modeling_tf_utils.py +++ b/pytorch_transformers/modeling_tf_utils.py @@ -24,7 +24,7 @@ import os import tensorflow as tf from .configuration_utils import PretrainedConfig -from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME +from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME logger = logging.getLogger(__name__) @@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model): model_kwargs = kwargs # Load model - if pretrained_model_name_or_path in cls.pretrained_model_archive_map: - archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] - elif os.path.isdir(pretrained_model_name_or_path): - if from_pt: - # Load from a PyTorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - else: - archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME) - else: - archive_file = pretrained_model_name_or_path - # redirect to the cache, if necessary - try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) - except EnvironmentError: + if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: - logger.error( - "Couldn't reach server at '{}' to download pretrained weights.".format( - archive_file)) + archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): + # Load from a TF 2.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + raise EnvironmentError("Error no file named {} found in directory {}".format( + tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME), + pretrained_model_name_or_path)) + elif os.path.isfile(pretrained_model_name_or_path): + archive_file = pretrained_model_name_or_path else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name_or_path, - ', '.join(cls.pretrained_model_archive_map.keys()), - archive_file)) - return None - if resolved_archive_file == archive_file: - logger.info("loading weights file {}".format(archive_file)) + raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path)) + + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) + except EnvironmentError as e: + if pretrained_model_name_or_path in cls.pretrained_model_archive_map: + logger.error( + "Couldn't reach server at '{}' to download pretrained weights.".format( + archive_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(cls.pretrained_model_archive_map.keys()), + archive_file)) + raise e + if resolved_archive_file == archive_file: + logger.info("loading weights file {}".format(archive_file)) + else: + logger.info("loading weights file {} from cache at {}".format( + archive_file, resolved_archive_file)) else: - logger.info("loading weights file {} from cache at {}".format( - archive_file, resolved_archive_file)) + resolved_archive_file = None # Instantiate model. model = cls(config, *model_args, **model_kwargs) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index aadf410a19..790d4dcd54 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss from torch.nn import functional as F from .configuration_utils import PretrainedConfig -from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME +from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME logger = logging.getLogger(__name__) @@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module): if pretrained_model_name_or_path in cls.pretrained_model_archive_map: archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] elif os.path.isdir(pretrained_model_name_or_path): - if from_tf: - # Directly load from a TensorFlow checkpoint + if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): + # Load from a TF 1.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") - else: + elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): + # Load from a TF 2.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + raise EnvironmentError("Error no file named {} found in directory {}".format( + tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"), + pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path else: