Fix cardinality (#9505)

This commit is contained in:
Julien Plu
2021-01-11 15:42:19 +01:00
committed by GitHub
parent 33b7422839
commit ba702966ba

View File

@@ -135,7 +135,7 @@ class TFTrainer:
raise ValueError("Trainer: training requires a train_dataset.")
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
self.num_train_examples = self.train_dataset.cardinality(self.train_dataset).numpy()
self.num_train_examples = self.train_dataset.cardinality().numpy()
if self.num_train_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")
@@ -167,7 +167,7 @@ class TFTrainer:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
num_examples = eval_dataset.cardinality(eval_dataset).numpy()
num_examples = eval_dataset.cardinality().numpy()
if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")
@@ -197,7 +197,7 @@ class TFTrainer:
Subclass and override this method if you want to inject some custom behavior.
"""
num_examples = test_dataset.cardinality(test_dataset).numpy()
num_examples = test_dataset.cardinality().numpy()
if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")