From 1a2b40cb53477b94c66718bac8d997297fcc8043 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 31 Oct 2019 18:00:51 -0400 Subject: [PATCH] run_tf_glue MRPC evaluation only for MRPC --- examples/run_tf_glue.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/examples/run_tf_glue.py b/examples/run_tf_glue.py index 73173b0cf1..8878ce726e 100644 --- a/examples/run_tf_glue.py +++ b/examples/run_tf_glue.py @@ -71,20 +71,21 @@ history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps, os.makedirs('./save/', exist_ok=True) model.save_pretrained('./save/') -# Load the TensorFlow model in PyTorch for inspection -pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True) +if TASK == "mrpc": + # Load the TensorFlow model in PyTorch for inspection + pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True) -# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task -sentence_0 = 'This research was consistent with his findings.' -sentence_1 = 'His findings were compatible with this research.' -sentence_2 = 'His findings were not compatible with this research.' -inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt') -inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt') + # Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task + sentence_0 = 'This research was consistent with his findings.' + sentence_1 = 'His findings were compatible with this research.' + sentence_2 = 'His findings were not compatible with this research.' + inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt') + inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt') -del inputs_1["special_tokens_mask"] -del inputs_2["special_tokens_mask"] + del inputs_1["special_tokens_mask"] + del inputs_2["special_tokens_mask"] -pred_1 = pytorch_model(**inputs_1)[0].argmax().item() -pred_2 = pytorch_model(**inputs_2)[0].argmax().item() -print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0') -print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0') + pred_1 = pytorch_model(**inputs_1)[0].argmax().item() + pred_2 = pytorch_model(**inputs_2)[0].argmax().item() + print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0') + print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')