Rewrite TensorFlow train_step and test_step (#17057)
* Initial commit * Better label renaming * Remove breakpoint before pushing (this is your job) * Test a lot more in the Keras fit() test * make fixup * Clarify the case where we flatten y dicts into tensors * Clarify the case where we flatten y dicts into tensors * Extract label name remapping to a method
This commit is contained in:
@@ -1355,7 +1355,25 @@ class TFModelTesterMixin:
|
||||
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
||||
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
||||
self.assertGreater(len(inputs_minus_labels), 0)
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
|
||||
accuracy_classes = [
|
||||
"ForPreTraining",
|
||||
"ForCausalLM",
|
||||
"ForMaskedLM",
|
||||
"ForQuestionAnswering",
|
||||
"ForMultipleChoice",
|
||||
"ForSequenceClassification",
|
||||
"ForTokenClassification",
|
||||
"ForNextSentencePrediction",
|
||||
"LMHeadModel",
|
||||
]
|
||||
for accuracy_class in accuracy_classes:
|
||||
if model.__class__.__name__.endswith(accuracy_class):
|
||||
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
|
||||
break
|
||||
else:
|
||||
metrics = []
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
|
||||
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||
history1 = model.fit(
|
||||
prepared_for_class,
|
||||
@@ -1365,6 +1383,7 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss1 = history1.history["val_loss"][0]
|
||||
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||
history2 = model.fit(
|
||||
inputs_minus_labels,
|
||||
labels,
|
||||
@@ -1374,7 +1393,14 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss2 = history2.history["val_loss"][0]
|
||||
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||
for key in history1.history.keys():
|
||||
if not key.startswith("val_"):
|
||||
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
|
||||
if metrics:
|
||||
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
|
||||
|
||||
def test_int64_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user