Speed up TF tests by reducing hidden layer counts (#24595)
* hidden layers, huh, what are they good for (absolutely nothing) * Some tests break with 1 hidden layer, use 2 * Use 1 hidden layer in a few slow models * Use num_hidden_layers=2 everywhere * Slightly higher tol for groupvit * Slightly higher tol for groupvit
This commit is contained in:
@@ -1527,36 +1527,6 @@ class TFModelTesterMixin:
|
||||
if metrics:
|
||||
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
|
||||
|
||||
# Make sure fit works with tf.data.Dataset and results are consistent
|
||||
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
|
||||
|
||||
if sample_weight is not None:
|
||||
# Add in the sample weight
|
||||
weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
|
||||
else:
|
||||
weighted_dataset = dataset
|
||||
# Pass in all samples as a batch to match other `fit` calls
|
||||
weighted_dataset = weighted_dataset.batch(len(dataset))
|
||||
dataset = dataset.batch(len(dataset))
|
||||
# Reinitialize to fix batchnorm again
|
||||
model.set_weights(model_weights)
|
||||
|
||||
# To match the other calls, don't pass sample weights in the validation data
|
||||
history3 = model.fit(
|
||||
weighted_dataset,
|
||||
validation_data=dataset,
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss3 = history3.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss3))
|
||||
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
|
||||
self.check_keras_fit_results(val_loss1, val_loss3)
|
||||
self.assertEqual(history1.history.keys(), history3.history.keys())
|
||||
if metrics:
|
||||
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
|
||||
|
||||
def test_int_support(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user