From 9ad830596d91fee96402e9af4b32a167e9f349dd Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Fri, 9 Oct 2020 16:38:25 +0200 Subject: [PATCH] Fix dataset cardinality (#7678) * Fix test * Fix cardinality issue * Fix test --- .../text-classification/run_tf_text_classification.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/text-classification/run_tf_text_classification.py b/examples/text-classification/run_tf_text_classification.py index 40472da47e..657119abb8 100644 --- a/examples/text-classification/run_tf_text_classification.py +++ b/examples/text-classification/run_tf_text_classification.py @@ -96,6 +96,9 @@ def get_tfds( else None ) + if train_ds is not None: + train_ds = train_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TRAIN]))) + val_ds = ( tf.data.Dataset.from_generator( gen_val, @@ -106,6 +109,9 @@ def get_tfds( else None ) + if val_ds is not None: + val_ds = val_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.VALIDATION]))) + test_ds = ( tf.data.Dataset.from_generator( gen_test, @@ -116,6 +122,9 @@ def get_tfds( else None ) + if test_ds is not None: + test_ds = test_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TEST]))) + return train_ds, val_ds, test_ds, label2id