From 1a7ef3349fd9acfc8ee7bc718f9864ac4ea9d064 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 Jun 2022 15:59:53 +0100 Subject: [PATCH] Fix broken test for models with batchnorm (#17841) * Fix tests that broke when models used batchnorm * Initializing the model twice does not actually... ...give you the same weights each time. I am good at machine learning. * Fix speed regression --- tests/test_modeling_tf_common.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 843ddaa5e3..1d09972520 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1383,6 +1383,10 @@ class TFModelTesterMixin: else: metrics = [] + model(model.dummy_inputs) # Build the model so we can get some constant weights + model_weights = model.get_weights() + + # Run eagerly to save some expensive compilation times model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics) # Make sure the model fits without crashing regardless of where we pass the labels history1 = model.fit( @@ -1394,6 +1398,11 @@ class TFModelTesterMixin: ) val_loss1 = history1.history["val_loss"][0] accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")} + + # We reinitialize the model here even though our learning rate was zero + # because BatchNorm updates weights by means other than gradient descent. + model.set_weights(model_weights) + history2 = model.fit( inputs_minus_labels, labels, @@ -1403,7 +1412,7 @@ class TFModelTesterMixin: shuffle=False, ) val_loss2 = history2.history["val_loss"][0] - accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")} + accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")} self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3)) self.assertEqual(history1.history.keys(), history2.history.keys()) for key in history1.history.keys(): @@ -1416,6 +1425,10 @@ class TFModelTesterMixin: dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class) # Pass in all samples as a batch to match other `fit` calls dataset = dataset.batch(len(dataset)) + + # Reinitialize to fix batchnorm again + model.set_weights(model_weights) + history3 = model.fit( dataset, validation_data=dataset,