From 29570db25ba9dd30e5ac9be68dbcad95434964ec Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 11 Dec 2019 17:19:18 +0100 Subject: [PATCH] allowing from_pretrained to load from url directly --- transformers/modeling_tf_utils.py | 4 +++- transformers/modeling_utils.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index ed8fdb74c9..e7512b5bd6 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -259,8 +259,10 @@ class TFPreTrainedModel(tf.keras.Model): pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + archive_file = pretrained_model_name_or_path + ".index" else: - raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path)) + archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 3ac568771e..9e7ca8d689 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -365,9 +365,12 @@ class PreTrainedModel(nn.Module): pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path - else: - assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path) + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( + pretrained_model_name_or_path + ".index") archive_file = pretrained_model_name_or_path + ".index" + else: + archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: