Fix cardinality (#9505)
This commit is contained in:
@@ -135,7 +135,7 @@ class TFTrainer:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
|
||||
self.num_train_examples = self.train_dataset.cardinality(self.train_dataset).numpy()
|
||||
self.num_train_examples = self.train_dataset.cardinality().numpy()
|
||||
|
||||
if self.num_train_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
@@ -167,7 +167,7 @@ class TFTrainer:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
num_examples = eval_dataset.cardinality(eval_dataset).numpy()
|
||||
num_examples = eval_dataset.cardinality().numpy()
|
||||
|
||||
if num_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
@@ -197,7 +197,7 @@ class TFTrainer:
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
"""
|
||||
|
||||
num_examples = test_dataset.cardinality(test_dataset).numpy()
|
||||
num_examples = test_dataset.cardinality().numpy()
|
||||
|
||||
if num_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
|
||||
Reference in New Issue
Block a user