allowing from_pretrained to load from url directly

This commit is contained in:
thomwolf
2019-12-11 17:19:18 +01:00
parent 2e2f9fed55
commit 29570db25b
2 changed files with 8 additions and 3 deletions

View File

@@ -259,8 +259,10 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
else:
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:

View File

@@ -365,9 +365,12 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
pretrained_model_name_or_path + ".index")
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try: