From 3b7fb48c3b963a88809263140e56108ac15225f8 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 25 Sep 2019 17:46:16 +0200 Subject: [PATCH] fix loading from tf/pt --- examples/run_tf_glue.py | 3 ++- pytorch_transformers/modeling_tf_utils.py | 4 ++-- pytorch_transformers/modeling_utils.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/run_tf_glue.py b/examples/run_tf_glue.py index 6f59d15286..4328ff5170 100644 --- a/examples/run_tf_glue.py +++ b/examples/run_tf_glue.py @@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/') pt_model = BertForSequenceClassification.from_pretrained('./runs/') # Quickly inspect a few predictions - +inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results", add_special_tokens=True) +pred = pt_model(torch.tensor([tokens])) # Divers import torch diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py index 21faee6616..e9db995b4d 100644 --- a/pytorch_transformers/modeling_tf_utils.py +++ b/pytorch_transformers/modeling_tf_utils.py @@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) else: - raise EnvironmentError("Error no file named {} found in directory {}".format( - tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME), + raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format( + [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 00e1156125..541ef7c741 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) else: - raise EnvironmentError("Error no file named {} found in directory {}".format( + raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format( [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path):