Make DataCollator a callable (#5015)
* Make DataCollator a callable * Update src/transformers/data/data_collator.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -38,7 +38,6 @@ from transformers import (
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
DefaultDataCollator,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
@@ -51,6 +50,7 @@ from transformers import (
|
||||
XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer,
|
||||
default_data_collator,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils_hans import HansDataset, hans_output_modes, hans_processors
|
||||
@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=DefaultDataCollator().collate_batch,
|
||||
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
|
||||
)
|
||||
|
||||
if args.max_steps > 0:
|
||||
@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
sampler=eval_sampler,
|
||||
batch_size=args.eval_batch_size,
|
||||
collate_fn=DefaultDataCollator().collate_batch,
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
|
||||
)
|
||||
|
||||
# multi-gpu eval
|
||||
|
||||
@@ -34,8 +34,8 @@ from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DefaultDataCollator,
|
||||
GlueDataset,
|
||||
default_data_collator,
|
||||
glue_compute_metrics,
|
||||
glue_output_modes,
|
||||
glue_processors,
|
||||
@@ -424,7 +424,7 @@ def main():
|
||||
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
|
||||
)
|
||||
|
||||
# Compute head entropy and importance score
|
||||
|
||||
@@ -364,7 +364,7 @@ if is_torch_available():
|
||||
|
||||
# Trainer
|
||||
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
|
||||
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
|
||||
from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
|
||||
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
|
||||
|
||||
# Benchmarks
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NewType, Tuple
|
||||
from typing import Any, Callable, Dict, List, NewType, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
class DataCollator(ABC):
|
||||
"""
|
||||
A `DataCollator` is responsible for batching
|
||||
and pre-processing samples of data as requested by the training loop.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def collate_batch(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Take a list of samples from a Dataset and collate them into a batch.
|
||||
|
||||
Returns:
|
||||
A dictionary of tensors
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
InputDataClass = NewType("InputDataClass", Any)
|
||||
|
||||
"""
|
||||
A DataCollator is a function that takes a list of samples from a Dataset
|
||||
and collate them into a batch, as a dictionary of Tensors.
|
||||
"""
|
||||
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
|
||||
|
||||
@dataclass
|
||||
class DefaultDataCollator(DataCollator):
|
||||
|
||||
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Very simple data collator that:
|
||||
- simply collates batches of dict-like objects
|
||||
@@ -42,41 +29,40 @@ class DefaultDataCollator(DataCollator):
|
||||
See glue and ner for example of how it's useful.
|
||||
"""
|
||||
|
||||
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
|
||||
# In this method we'll make the assumption that all `features` in the batch
|
||||
# have the same attributes.
|
||||
# So we will look at the first element as a proxy for what attributes exist
|
||||
# on the whole batch.
|
||||
first = features[0]
|
||||
# In this function we'll make the assumption that all `features` in the batch
|
||||
# have the same attributes.
|
||||
# So we will look at the first element as a proxy for what attributes exist
|
||||
# on the whole batch.
|
||||
first = features[0]
|
||||
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if hasattr(first, "label") and first.label is not None:
|
||||
if type(first.label) is int:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
||||
if type(first.label_ids[0]) is int:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if hasattr(first, "label") and first.label is not None:
|
||||
if type(first.label) is int:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
else:
|
||||
batch = {}
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
||||
if type(first.label_ids[0]) is int:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
else:
|
||||
batch = {}
|
||||
|
||||
# Handling of all other possible attributes.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in vars(first).items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
||||
return batch
|
||||
# Handling of all other possible attributes.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in vars(first).items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForLanguageModeling(DataCollator):
|
||||
class DataCollatorForLanguageModeling:
|
||||
"""
|
||||
Data collator used for language modeling.
|
||||
- collates batches of tensors, honoring their tokenizer's pad_token
|
||||
@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
|
||||
mlm: bool = True
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
batch = self._tensorize_batch(examples)
|
||||
if self.mlm:
|
||||
inputs, labels = self.mask_tokens(batch)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
||||
from tqdm.auto import tqdm, trange
|
||||
|
||||
from .data.data_collator import DataCollator, DefaultDataCollator
|
||||
from .data.data_collator import DataCollator, default_data_collator
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
|
||||
@@ -190,10 +190,7 @@ class Trainer:
|
||||
"""
|
||||
self.model = model.to(args.device)
|
||||
self.args = args
|
||||
if data_collator is not None:
|
||||
self.data_collator = data_collator
|
||||
else:
|
||||
self.data_collator = DefaultDataCollator()
|
||||
self.data_collator = data_collator if data_collator is not None else default_data_collator
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.compute_metrics = compute_metrics
|
||||
@@ -239,7 +236,7 @@ class Trainer:
|
||||
self.train_dataset,
|
||||
batch_size=self.args.train_batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
)
|
||||
|
||||
@@ -264,7 +261,7 @@ class Trainer:
|
||||
eval_dataset,
|
||||
sampler=sampler,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
)
|
||||
|
||||
@@ -285,7 +282,7 @@ class Trainer:
|
||||
test_dataset,
|
||||
sampler=sampler,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ if is_torch_available():
|
||||
Trainer,
|
||||
LineByLineTextDataset,
|
||||
AutoModelForSequenceClassification,
|
||||
DefaultDataCollator,
|
||||
default_data_collator,
|
||||
DataCollatorForLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
||||
)
|
||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||
data_collator = DefaultDataCollator()
|
||||
batch = data_collator.collate_batch(dataset.features)
|
||||
data_collator = default_data_collator
|
||||
batch = data_collator(dataset.features)
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
|
||||
def test_default_regression(self):
|
||||
@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
||||
)
|
||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||
data_collator = DefaultDataCollator()
|
||||
batch = data_collator.collate_batch(dataset.features)
|
||||
data_collator = default_data_collator
|
||||
batch = data_collator(dataset.features)
|
||||
self.assertEqual(batch["labels"].dtype, torch.float)
|
||||
|
||||
def test_lm_tokenizer_without_padding(self):
|
||||
@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
with self.assertRaises(ValueError):
|
||||
# Expect error due to padding token missing on gpt2:
|
||||
data_collator.collate_batch(examples)
|
||||
data_collator(examples)
|
||||
|
||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator.collate_batch(examples)
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
|
||||
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator.collate_batch(examples)
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
|
||||
|
||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator.collate_batch(examples)
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
|
||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from transformers import DataCollator, Trainer
|
||||
from transformers import Trainer
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
def __init__(self, length: int = 101):
|
||||
@@ -41,8 +41,8 @@ if is_torch_available():
|
||||
def __getitem__(self, i) -> int:
|
||||
return i
|
||||
|
||||
class DummyDataCollator(DataCollator):
|
||||
def collate_batch(self, features):
|
||||
class DummyDataCollator:
|
||||
def __call__(self, features):
|
||||
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user