Rewrite TensorFlow train_step and test_step (#17057)

* Initial commit

* Better label renaming

* Remove breakpoint before pushing (this is your job)

* Test a lot more in the Keras fit() test

* make fixup

* Clarify the case where we flatten y dicts into tensors

* Clarify the case where we flatten y dicts into tensors

* Extract label name remapping to a method
This commit is contained in:
Matt
2022-05-17 14:36:23 +01:00
committed by GitHub
parent 651e48e1e5
commit 349f1c85d3
2 changed files with 148 additions and 47 deletions

View File

@@ -1355,7 +1355,25 @@ class TFModelTesterMixin:
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
accuracy_classes = [
"ForPreTraining",
"ForCausalLM",
"ForMaskedLM",
"ForQuestionAnswering",
"ForMultipleChoice",
"ForSequenceClassification",
"ForTokenClassification",
"ForNextSentencePrediction",
"LMHeadModel",
]
for accuracy_class in accuracy_classes:
if model.__class__.__name__.endswith(accuracy_class):
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
break
else:
metrics = []
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
@@ -1365,6 +1383,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
history2 = model.fit(
inputs_minus_labels,
labels,
@@ -1374,7 +1393,14 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
if not key.startswith("val_"):
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
def test_int64_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()