allowing from_pretrained to load from url directly
This commit is contained in:
@@ -259,8 +259,10 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
|
||||
archive_file = pretrained_model_name_or_path
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
|
||||
@@ -365,9 +365,12 @@ class PreTrainedModel(nn.Module):
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
||||
pretrained_model_name_or_path + ".index")
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
archive_file = pretrained_model_name_or_path
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user