run_tf_glue MRPC evaluation only for MRPC
This commit is contained in:
@@ -71,20 +71,21 @@ history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
|
|||||||
os.makedirs('./save/', exist_ok=True)
|
os.makedirs('./save/', exist_ok=True)
|
||||||
model.save_pretrained('./save/')
|
model.save_pretrained('./save/')
|
||||||
|
|
||||||
# Load the TensorFlow model in PyTorch for inspection
|
if TASK == "mrpc":
|
||||||
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
|
# Load the TensorFlow model in PyTorch for inspection
|
||||||
|
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
|
||||||
|
|
||||||
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
|
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
|
||||||
sentence_0 = 'This research was consistent with his findings.'
|
sentence_0 = 'This research was consistent with his findings.'
|
||||||
sentence_1 = 'His findings were compatible with this research.'
|
sentence_1 = 'His findings were compatible with this research.'
|
||||||
sentence_2 = 'His findings were not compatible with this research.'
|
sentence_2 = 'His findings were not compatible with this research.'
|
||||||
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
|
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
|
||||||
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
|
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
|
||||||
|
|
||||||
del inputs_1["special_tokens_mask"]
|
del inputs_1["special_tokens_mask"]
|
||||||
del inputs_2["special_tokens_mask"]
|
del inputs_2["special_tokens_mask"]
|
||||||
|
|
||||||
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
||||||
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
||||||
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
|
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
|
||||||
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
|
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
|
||||||
|
|||||||
Reference in New Issue
Block a user