Slightly alter Keras dummy loss (#20232)

* Slightly alter Keras dummy loss

* Slightly alter Keras dummy loss

* Add sample weight to test_keras_fit

* Fix test_keras_fit for datasets

* Skip the sample_weight stuff for models where the model tester has no batch_size
This commit is contained in:
Matt
2022-11-15 16:58:43 +00:00
committed by GitHub
parent 7f74433814
commit 26ec7928d0
2 changed files with 21 additions and 2 deletions

View File

@@ -1544,6 +1544,11 @@ class TFModelTesterMixin:
else:
metrics = []
if hasattr(self.model_tester, "batch_size"):
sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
else:
sample_weight = None
model(model.dummy_inputs) # Build the model so we can get some constant weights
model_weights = model.get_weights()
@@ -1553,6 +1558,7 @@ class TFModelTesterMixin:
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
sample_weight=sample_weight,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
@@ -1588,6 +1594,7 @@ class TFModelTesterMixin:
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
sample_weight=sample_weight,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
@@ -1605,14 +1612,22 @@ class TFModelTesterMixin:
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
if sample_weight is not None:
# Add in the sample weight
weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
else:
weighted_dataset = dataset
# Pass in all samples as a batch to match other `fit` calls
weighted_dataset = weighted_dataset.batch(len(dataset))
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)
# To match the other calls, don't pass sample weights in the validation data
history3 = model.fit(
dataset,
weighted_dataset,
validation_data=dataset,
steps_per_epoch=1,
validation_steps=1,