Make DataCollator a callable (#5015)
* Make DataCollator a callable * Update src/transformers/data/data_collator.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -38,7 +38,6 @@ from transformers import (
|
|||||||
BertConfig,
|
BertConfig,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
DefaultDataCollator,
|
|
||||||
DistilBertConfig,
|
DistilBertConfig,
|
||||||
DistilBertForSequenceClassification,
|
DistilBertForSequenceClassification,
|
||||||
DistilBertTokenizer,
|
DistilBertTokenizer,
|
||||||
@@ -51,6 +50,7 @@ from transformers import (
|
|||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
|
default_data_collator,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from utils_hans import HansDataset, hans_output_modes, hans_processors
|
from utils_hans import HansDataset, hans_output_modes, hans_processors
|
||||||
@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
|
||||||
sampler=train_sampler,
|
|
||||||
batch_size=args.train_batch_size,
|
|
||||||
collate_fn=DefaultDataCollator().collate_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.max_steps > 0:
|
if args.max_steps > 0:
|
||||||
@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
|
|||||||
# Note that DistributedSampler samples randomly
|
# Note that DistributedSampler samples randomly
|
||||||
eval_sampler = SequentialSampler(eval_dataset)
|
eval_sampler = SequentialSampler(eval_dataset)
|
||||||
eval_dataloader = DataLoader(
|
eval_dataloader = DataLoader(
|
||||||
eval_dataset,
|
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
|
||||||
sampler=eval_sampler,
|
|
||||||
batch_size=args.eval_batch_size,
|
|
||||||
collate_fn=DefaultDataCollator().collate_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
|
|||||||
@@ -34,8 +34,8 @@ from transformers import (
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DefaultDataCollator,
|
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
|
default_data_collator,
|
||||||
glue_compute_metrics,
|
glue_compute_metrics,
|
||||||
glue_output_modes,
|
glue_output_modes,
|
||||||
glue_processors,
|
glue_processors,
|
||||||
@@ -424,7 +424,7 @@ def main():
|
|||||||
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
||||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||||
eval_dataloader = DataLoader(
|
eval_dataloader = DataLoader(
|
||||||
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch
|
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute head entropy and importance score
|
# Compute head entropy and importance score
|
||||||
|
|||||||
@@ -364,7 +364,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
|
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
|
||||||
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
|
from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
|
||||||
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
|
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
|
||||||
|
|
||||||
# Benchmarks
|
# Benchmarks
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, NewType, Tuple
|
from typing import Any, Callable, Dict, List, NewType, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
|
|||||||
from ..tokenization_utils import PreTrainedTokenizer
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
class DataCollator(ABC):
|
|
||||||
"""
|
|
||||||
A `DataCollator` is responsible for batching
|
|
||||||
and pre-processing samples of data as requested by the training loop.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collate_batch(self) -> Dict[str, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Take a list of samples from a Dataset and collate them into a batch.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of tensors
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
InputDataClass = NewType("InputDataClass", Any)
|
InputDataClass = NewType("InputDataClass", Any)
|
||||||
|
|
||||||
|
"""
|
||||||
|
A DataCollator is a function that takes a list of samples from a Dataset
|
||||||
|
and collate them into a batch, as a dictionary of Tensors.
|
||||||
|
"""
|
||||||
|
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DefaultDataCollator(DataCollator):
|
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Very simple data collator that:
|
Very simple data collator that:
|
||||||
- simply collates batches of dict-like objects
|
- simply collates batches of dict-like objects
|
||||||
@@ -42,8 +29,7 @@ class DefaultDataCollator(DataCollator):
|
|||||||
See glue and ner for example of how it's useful.
|
See glue and ner for example of how it's useful.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
|
# In this function we'll make the assumption that all `features` in the batch
|
||||||
# In this method we'll make the assumption that all `features` in the batch
|
|
||||||
# have the same attributes.
|
# have the same attributes.
|
||||||
# So we will look at the first element as a proxy for what attributes exist
|
# So we will look at the first element as a proxy for what attributes exist
|
||||||
# on the whole batch.
|
# on the whole batch.
|
||||||
@@ -76,7 +62,7 @@ class DefaultDataCollator(DataCollator):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForLanguageModeling(DataCollator):
|
class DataCollatorForLanguageModeling:
|
||||||
"""
|
"""
|
||||||
Data collator used for language modeling.
|
Data collator used for language modeling.
|
||||||
- collates batches of tensors, honoring their tokenizer's pad_token
|
- collates batches of tensors, honoring their tokenizer's pad_token
|
||||||
@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
|
|||||||
mlm: bool = True
|
mlm: bool = True
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
|
||||||
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
inputs, labels = self.mask_tokens(batch)
|
inputs, labels = self.mask_tokens(batch)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
||||||
from tqdm.auto import tqdm, trange
|
from tqdm.auto import tqdm, trange
|
||||||
|
|
||||||
from .data.data_collator import DataCollator, DefaultDataCollator
|
from .data.data_collator import DataCollator, default_data_collator
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
|
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
|
||||||
@@ -190,10 +190,7 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
self.model = model.to(args.device)
|
self.model = model.to(args.device)
|
||||||
self.args = args
|
self.args = args
|
||||||
if data_collator is not None:
|
self.data_collator = data_collator if data_collator is not None else default_data_collator
|
||||||
self.data_collator = data_collator
|
|
||||||
else:
|
|
||||||
self.data_collator = DefaultDataCollator()
|
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
@@ -239,7 +236,7 @@ class Trainer:
|
|||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
batch_size=self.args.train_batch_size,
|
batch_size=self.args.train_batch_size,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
collate_fn=self.data_collator.collate_batch,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -264,7 +261,7 @@ class Trainer:
|
|||||||
eval_dataset,
|
eval_dataset,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator.collate_batch,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -285,7 +282,7 @@ class Trainer:
|
|||||||
test_dataset,
|
test_dataset,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator.collate_batch,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ if is_torch_available():
|
|||||||
Trainer,
|
Trainer,
|
||||||
LineByLineTextDataset,
|
LineByLineTextDataset,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
DefaultDataCollator,
|
default_data_collator,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
||||||
)
|
)
|
||||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||||
data_collator = DefaultDataCollator()
|
data_collator = default_data_collator
|
||||||
batch = data_collator.collate_batch(dataset.features)
|
batch = data_collator(dataset.features)
|
||||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
|
|
||||||
def test_default_regression(self):
|
def test_default_regression(self):
|
||||||
@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
||||||
)
|
)
|
||||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||||
data_collator = DefaultDataCollator()
|
data_collator = default_data_collator
|
||||||
batch = data_collator.collate_batch(dataset.features)
|
batch = data_collator(dataset.features)
|
||||||
self.assertEqual(batch["labels"].dtype, torch.float)
|
self.assertEqual(batch["labels"].dtype, torch.float)
|
||||||
|
|
||||||
def test_lm_tokenizer_without_padding(self):
|
def test_lm_tokenizer_without_padding(self):
|
||||||
@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
examples = [dataset[i] for i in range(len(dataset))]
|
examples = [dataset[i] for i in range(len(dataset))]
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
# Expect error due to padding token missing on gpt2:
|
# Expect error due to padding token missing on gpt2:
|
||||||
data_collator.collate_batch(examples)
|
data_collator(examples)
|
||||||
|
|
||||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||||
examples = [dataset[i] for i in range(len(dataset))]
|
examples = [dataset[i] for i in range(len(dataset))]
|
||||||
batch = data_collator.collate_batch(examples)
|
batch = data_collator(examples)
|
||||||
self.assertIsInstance(batch, dict)
|
self.assertIsInstance(batch, dict)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||||
@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
||||||
examples = [dataset[i] for i in range(len(dataset))]
|
examples = [dataset[i] for i in range(len(dataset))]
|
||||||
batch = data_collator.collate_batch(examples)
|
batch = data_collator(examples)
|
||||||
self.assertIsInstance(batch, dict)
|
self.assertIsInstance(batch, dict)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
|
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
|
||||||
|
|
||||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||||
examples = [dataset[i] for i in range(len(dataset))]
|
examples = [dataset[i] for i in range(len(dataset))]
|
||||||
batch = data_collator.collate_batch(examples)
|
batch = data_collator(examples)
|
||||||
self.assertIsInstance(batch, dict)
|
self.assertIsInstance(batch, dict)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
from transformers import DataCollator, Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
class DummyDataset(Dataset):
|
class DummyDataset(Dataset):
|
||||||
def __init__(self, length: int = 101):
|
def __init__(self, length: int = 101):
|
||||||
@@ -41,8 +41,8 @@ if is_torch_available():
|
|||||||
def __getitem__(self, i) -> int:
|
def __getitem__(self, i) -> int:
|
||||||
return i
|
return i
|
||||||
|
|
||||||
class DummyDataCollator(DataCollator):
|
class DummyDataCollator:
|
||||||
def collate_batch(self, features):
|
def __call__(self, features):
|
||||||
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
|
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
|
||||||
|
|
||||||
class DummyModel(nn.Module):
|
class DummyModel(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user