Add drop_last arg for data loader

This commit is contained in:
Setu Shah
2020-06-03 22:25:08 -07:00
committed by Julien Chaumond
parent 48a05026de
commit 0e1869cc28
2 changed files with 6 additions and 0 deletions

View File

@@ -240,6 +240,7 @@ class Trainer:
batch_size=self.args.train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator.collate_batch,
drop_last=self.args.dataloader_drop_last,
)
return data_loader
@@ -264,6 +265,7 @@ class Trainer:
sampler=sampler,
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch,
drop_last=self.args.dataloader_drop_last,
)
return data_loader

View File

@@ -133,6 +133,10 @@ class TrainingArguments:
)
tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
@property
def train_batch_size(self) -> int:
if self.per_gpu_train_batch_size: