Fix dataset cardinality (#7678)

* Fix test

* Fix cardinality issue

* Fix test
This commit is contained in:
Julien Plu
2020-10-09 16:38:25 +02:00
committed by GitHub
parent a1ac082879
commit 9ad830596d

View File

@@ -96,6 +96,9 @@ def get_tfds(
else None
)
if train_ds is not None:
train_ds = train_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TRAIN])))
val_ds = (
tf.data.Dataset.from_generator(
gen_val,
@@ -106,6 +109,9 @@ def get_tfds(
else None
)
if val_ds is not None:
val_ds = val_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.VALIDATION])))
test_ds = (
tf.data.Dataset.from_generator(
gen_test,
@@ -116,6 +122,9 @@ def get_tfds(
else None
)
if test_ds is not None:
test_ds = test_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TEST])))
return train_ds, val_ds, test_ds, label2id