Make DataCollator a callable (#5015)
* Make DataCollator a callable * Update src/transformers/data/data_collator.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -38,7 +38,6 @@ from transformers import (
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
DefaultDataCollator,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
@@ -51,6 +50,7 @@ from transformers import (
|
||||
XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer,
|
||||
default_data_collator,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils_hans import HansDataset, hans_output_modes, hans_processors
|
||||
@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=DefaultDataCollator().collate_batch,
|
||||
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
|
||||
)
|
||||
|
||||
if args.max_steps > 0:
|
||||
@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
sampler=eval_sampler,
|
||||
batch_size=args.eval_batch_size,
|
||||
collate_fn=DefaultDataCollator().collate_batch,
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
|
||||
)
|
||||
|
||||
# multi-gpu eval
|
||||
|
||||
Reference in New Issue
Block a user