Make DataCollator a callable (#5015)

* Make DataCollator a callable

* Update src/transformers/data/data_collator.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-06-15 11:58:33 -04:00
committed by GitHub
parent f7c93b3cee
commit 1affde2f10
7 changed files with 60 additions and 83 deletions

View File

@@ -38,7 +38,6 @@ from transformers import (
BertConfig,
BertForSequenceClassification,
BertTokenizer,
DefaultDataCollator,
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer,
@@ -51,6 +50,7 @@ from transformers import (
XLNetConfig,
XLNetForSequenceClassification,
XLNetTokenizer,
default_data_collator,
get_linear_schedule_with_warmup,
)
from utils_hans import HansDataset, hans_output_modes, hans_processors
@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=DefaultDataCollator().collate_batch,
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
)
if args.max_steps > 0:
@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
eval_dataset,
sampler=eval_sampler,
batch_size=args.eval_batch_size,
collate_fn=DefaultDataCollator().collate_batch,
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
)
# multi-gpu eval