Fix broken test for models with batchnorm (#17841)
* Fix tests that broke when models used batchnorm * Initializing the model twice does not actually... ...give you the same weights each time. I am good at machine learning. * Fix speed regression
This commit is contained in:
@@ -1383,6 +1383,10 @@ class TFModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
metrics = []
|
metrics = []
|
||||||
|
|
||||||
|
model(model.dummy_inputs) # Build the model so we can get some constant weights
|
||||||
|
model_weights = model.get_weights()
|
||||||
|
|
||||||
|
# Run eagerly to save some expensive compilation times
|
||||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
|
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
|
||||||
# Make sure the model fits without crashing regardless of where we pass the labels
|
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||||
history1 = model.fit(
|
history1 = model.fit(
|
||||||
@@ -1394,6 +1398,11 @@ class TFModelTesterMixin:
|
|||||||
)
|
)
|
||||||
val_loss1 = history1.history["val_loss"][0]
|
val_loss1 = history1.history["val_loss"][0]
|
||||||
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||||
|
|
||||||
|
# We reinitialize the model here even though our learning rate was zero
|
||||||
|
# because BatchNorm updates weights by means other than gradient descent.
|
||||||
|
model.set_weights(model_weights)
|
||||||
|
|
||||||
history2 = model.fit(
|
history2 = model.fit(
|
||||||
inputs_minus_labels,
|
inputs_minus_labels,
|
||||||
labels,
|
labels,
|
||||||
@@ -1403,7 +1412,7 @@ class TFModelTesterMixin:
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
val_loss2 = history2.history["val_loss"][0]
|
val_loss2 = history2.history["val_loss"][0]
|
||||||
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
|
||||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||||
self.assertEqual(history1.history.keys(), history2.history.keys())
|
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||||
for key in history1.history.keys():
|
for key in history1.history.keys():
|
||||||
@@ -1416,6 +1425,10 @@ class TFModelTesterMixin:
|
|||||||
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
|
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
|
||||||
# Pass in all samples as a batch to match other `fit` calls
|
# Pass in all samples as a batch to match other `fit` calls
|
||||||
dataset = dataset.batch(len(dataset))
|
dataset = dataset.batch(len(dataset))
|
||||||
|
|
||||||
|
# Reinitialize to fix batchnorm again
|
||||||
|
model.set_weights(model_weights)
|
||||||
|
|
||||||
history3 = model.fit(
|
history3 = model.fit(
|
||||||
dataset,
|
dataset,
|
||||||
validation_data=dataset,
|
validation_data=dataset,
|
||||||
|
|||||||
Reference in New Issue
Block a user