run_tf_glue: run the MRPC paraphrase evaluation only when the task is MRPC

This commit is contained in:
Lysandre
2019-10-31 18:00:51 -04:00
parent be36cf92fb
commit 1a2b40cb53

View File

@@ -71,20 +71,21 @@ history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
os.makedirs('./save/', exist_ok=True) os.makedirs('./save/', exist_ok=True)
model.save_pretrained('./save/') model.save_pretrained('./save/')
# Load the TensorFlow model in PyTorch for inspection if TASK == "mrpc":
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True) # Load the TensorFlow model in PyTorch for inspection
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task # Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
sentence_0 = 'This research was consistent with his findings.' sentence_0 = 'This research was consistent with his findings.'
sentence_1 = 'His findings were compatible with this research.' sentence_1 = 'His findings were compatible with this research.'
sentence_2 = 'His findings were not compatible with this research.' sentence_2 = 'His findings were not compatible with this research.'
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt') inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt') inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
del inputs_1["special_tokens_mask"] del inputs_1["special_tokens_mask"]
del inputs_2["special_tokens_mask"] del inputs_2["special_tokens_mask"]
pred_1 = pytorch_model(**inputs_1)[0].argmax().item() pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
pred_2 = pytorch_model(**inputs_2)[0].argmax().item() pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0') print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0') print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')