Fix dataset cardinality (#7678)
* Fix test * Fix cardinality issue * Fix test
This commit is contained in:
@@ -96,6 +96,9 @@ def get_tfds(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if train_ds is not None:
|
||||||
|
train_ds = train_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TRAIN])))
|
||||||
|
|
||||||
val_ds = (
|
val_ds = (
|
||||||
tf.data.Dataset.from_generator(
|
tf.data.Dataset.from_generator(
|
||||||
gen_val,
|
gen_val,
|
||||||
@@ -106,6 +109,9 @@ def get_tfds(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if val_ds is not None:
|
||||||
|
val_ds = val_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.VALIDATION])))
|
||||||
|
|
||||||
test_ds = (
|
test_ds = (
|
||||||
tf.data.Dataset.from_generator(
|
tf.data.Dataset.from_generator(
|
||||||
gen_test,
|
gen_test,
|
||||||
@@ -116,6 +122,9 @@ def get_tfds(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if test_ds is not None:
|
||||||
|
test_ds = test_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TEST])))
|
||||||
|
|
||||||
return train_ds, val_ds, test_ds, label2id
|
return train_ds, val_ds, test_ds, label2id
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user