From ba702966bae6e2ea53e011a8988aa2ac565d0c96 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 11 Jan 2021 15:42:19 +0100 Subject: [PATCH] Fix cardinality (#9505) --- src/transformers/trainer_tf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index 6f41f349ff..ac75eb6223 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -135,7 +135,7 @@ class TFTrainer: raise ValueError("Trainer: training requires a train_dataset.") self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps - self.num_train_examples = self.train_dataset.cardinality(self.train_dataset).numpy() + self.num_train_examples = self.train_dataset.cardinality().numpy() if self.num_train_examples < 0: raise ValueError("The training dataset must have an asserted cardinality") @@ -167,7 +167,7 @@ class TFTrainer: raise ValueError("Trainer: evaluation requires an eval_dataset.") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - num_examples = eval_dataset.cardinality(eval_dataset).numpy() + num_examples = eval_dataset.cardinality().numpy() if num_examples < 0: raise ValueError("The training dataset must have an asserted cardinality") @@ -197,7 +197,7 @@ class TFTrainer: Subclass and override this method if you want to inject some custom behavior. """ - num_examples = test_dataset.cardinality(test_dataset).numpy() + num_examples = test_dataset.cardinality().numpy() if num_examples < 0: raise ValueError("The training dataset must have an asserted cardinality")