clean up run_tf_glue
This commit is contained in:
@@ -23,12 +23,6 @@ model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
history = model.fit(train_dataset, epochs=2, steps_per_epoch=115,
|
||||
validation_data=valid_dataset, validation_steps=7)
|
||||
|
||||
>>> Train for 115 steps, validate for 7 steps
|
||||
>>> Epoch 1/2
|
||||
>>> 115/115 [==============================] - 53s 459ms/step - loss: 0.6033 - accuracy: 0.6712 - val_loss: 0.4964 - val_accuracy: 0.7647
|
||||
>>> Epoch 2/2
|
||||
>>> 115/115 [==============================] - 33s 289ms/step - loss: 0.4141 - accuracy: 0.8160 - val_loss: 0.3914 - val_accuracy: 0.8382
|
||||
|
||||
# Load the TensorFlow model in PyTorch for inspection
|
||||
model.save_pretrained('./save/')
|
||||
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
|
||||
@@ -44,5 +38,3 @@ pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
||||
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
||||
print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")
|
||||
print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")
|
||||
>>> sentence_1 is a paraphrase of sentence_0
|
||||
>>> sentence_2 is not a paraphrase of sentence_0
|
||||
Reference in New Issue
Block a user