From 0e1869cc286d607f1598506be7bd1312b76ca82c Mon Sep 17 00:00:00 2001 From: Setu Shah Date: Wed, 3 Jun 2020 22:25:08 -0700 Subject: [PATCH] Add drop_last arg for data loader --- src/transformers/trainer.py | 2 ++ src/transformers/training_args.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d0914aa7a5..2caccdded9 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -240,6 +240,7 @@ class Trainer: batch_size=self.args.train_batch_size, sampler=train_sampler, collate_fn=self.data_collator.collate_batch, + drop_last=self.args.dataloader_drop_last, ) return data_loader @@ -264,6 +265,7 @@ class Trainer: sampler=sampler, batch_size=self.args.eval_batch_size, collate_fn=self.data_collator.collate_batch, + drop_last=self.args.dataloader_drop_last, ) return data_loader diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d46e3e9c5b..c571f1c083 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -133,6 +133,10 @@ class TrainingArguments: ) tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"}) + dataloader_drop_last: bool = field( + default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} + ) + @property def train_batch_size(self) -> int: if self.per_gpu_train_batch_size: