From 18e1f751f1d996c4fe01559ade1cd013186b81e4 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 11 Dec 2019 17:07:46 -0500 Subject: [PATCH] TF support --- transformers/modeling_tf_utils.py | 9 ++++++--- transformers/modeling_utils.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index e7512b5bd6..4a6d18f447 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -24,7 +24,8 @@ import os import tensorflow as tf from .configuration_utils import PretrainedConfig -from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME +from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, + cached_path, hf_bucket_url, is_remote_url) from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -257,12 +258,14 @@ class TFPreTrainedModel(tf.keras.Model): raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format( [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path)) - elif os.path.isfile(pretrained_model_name_or_path): + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path elif os.path.isfile(pretrained_model_name_or_path + ".index"): archive_file = pretrained_model_name_or_path + ".index" else: - archive_file = pretrained_model_name_or_path + archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME) + if from_pt: + raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.") # redirect to the cache, if necessary try: diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index eac4252336..37088f8e67 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -372,7 +372,8 @@ class PreTrainedModel(nn.Module): archive_file = pretrained_model_name_or_path + ".index" else: archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME) - # todo do we want to support TF checkpoints here? + if from_tf: + raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.") # redirect to the cache, if necessary try: