@julien-c proposal for TF/PT compat in hf_buckets
This commit is contained in:
@@ -303,11 +303,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME))
|
||||||
if from_pt:
|
|
||||||
raise EnvironmentError(
|
|
||||||
"Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name."
|
|
||||||
)
|
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -421,11 +421,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
)
|
)
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME))
|
||||||
if from_tf:
|
|
||||||
raise EnvironmentError(
|
|
||||||
"Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name."
|
|
||||||
)
|
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user