Use shape_list to safely get shapes for Swin (#17591)

* Use shape_list to safely get shapes

* Add relevant test

* Tidy and add metrics

* Resolve dynamic shaping issues and move test

* Tidy up and all samples in batch

* Formatting
This commit is contained in:
amyeroberts
2022-06-09 15:50:50 +02:00
committed by GitHub
parent e0be053e43
commit 9fc34235fa
4 changed files with 61 additions and 26 deletions

View File

@@ -1406,6 +1406,24 @@ class TFModelTesterMixin:
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
# Pass in all samples as a batch to match other `fit` calls
dataset = dataset.batch(len(dataset))
history3 = model.fit(
dataset,
validation_data=dataset,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss3 = history3.history["val_loss"][0]
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history3.history.keys())
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
def test_int64_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: