From 961c69776f8a2c95b92407a086848ebca037de5d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 7 Feb 2020 08:53:17 +0100 Subject: [PATCH] @julien-c proposal for TF/PT compat in hf_buckets --- src/transformers/modeling_tf_utils.py | 6 +----- src/transformers/modeling_utils.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 4b64f9364c..e5b00a68d7 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -303,11 +303,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): elif os.path.isfile(pretrained_model_name_or_path + ".index"): archive_file = pretrained_model_name_or_path + ".index" else: - archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME) - if from_pt: - raise EnvironmentError( - "Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name." - ) + archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)) # redirect to the cache, if necessary try: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4178e8ca1e..d8389e647e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -421,11 +421,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) archive_file = pretrained_model_name_or_path + ".index" else: - archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME) - if from_tf: - raise EnvironmentError( - "Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name." - ) + archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME)) # redirect to the cache, if necessary try: