Fix dataset cardinality (#7678)
* Fix test * Fix cardinality issue * Fix test
This commit is contained in:
@@ -96,6 +96,9 @@ def get_tfds(
|
||||
else None
|
||||
)
|
||||
|
||||
if train_ds is not None:
|
||||
train_ds = train_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TRAIN])))
|
||||
|
||||
val_ds = (
|
||||
tf.data.Dataset.from_generator(
|
||||
gen_val,
|
||||
@@ -106,6 +109,9 @@ def get_tfds(
|
||||
else None
|
||||
)
|
||||
|
||||
if val_ds is not None:
|
||||
val_ds = val_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.VALIDATION])))
|
||||
|
||||
test_ds = (
|
||||
tf.data.Dataset.from_generator(
|
||||
gen_test,
|
||||
@@ -116,6 +122,9 @@ def get_tfds(
|
||||
else None
|
||||
)
|
||||
|
||||
if test_ds is not None:
|
||||
test_ds = test_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TEST])))
|
||||
|
||||
return train_ds, val_ds, test_ds, label2id
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user