From 528c288fa94b0a258c610559c52fbc3a03e46805 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 27 Sep 2019 09:40:29 +0200 Subject: [PATCH] clean up run_tf_glue --- examples/run_tf_glue.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/run_tf_glue.py b/examples/run_tf_glue.py index 3a867f80a8..9235612cb0 100644 --- a/examples/run_tf_glue.py +++ b/examples/run_tf_glue.py @@ -23,12 +23,6 @@ model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) history = model.fit(train_dataset, epochs=2, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7) ->>> Train for 115 steps, validate for 7 steps ->>> Epoch 1/2 ->>> 115/115 [==============================] - 53s 459ms/step - loss: 0.6033 - accuracy: 0.6712 - val_loss: 0.4964 - val_accuracy: 0.7647 ->>> Epoch 2/2 ->>> 115/115 [==============================] - 33s 289ms/step - loss: 0.4141 - accuracy: 0.8160 - val_loss: 0.3914 - val_accuracy: 0.8382 - # Load the TensorFlow model in PyTorch for inspection model.save_pretrained('./save/') pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True) @@ -44,5 +38,3 @@ pred_1 = pytorch_model(**inputs_1)[0].argmax().item() pred_2 = pytorch_model(**inputs_2)[0].argmax().item() print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0") print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0") ->>> sentence_1 is a paraphrase of sentence_0 ->>> sentence_2 is not a paraphrase of sentence_0 \ No newline at end of file