diff --git a/examples/adversarial/test_hans.py b/examples/adversarial/test_hans.py index 3d8cf08598..91298811b8 100644 --- a/examples/adversarial/test_hans.py +++ b/examples/adversarial/test_hans.py @@ -38,7 +38,6 @@ from transformers import ( BertConfig, BertForSequenceClassification, BertTokenizer, - DefaultDataCollator, DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer, @@ -51,6 +50,7 @@ from transformers import ( XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer, + default_data_collator, get_linear_schedule_with_warmup, ) from utils_hans import HansDataset, hans_output_modes, hans_processors @@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer): args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) train_dataloader = DataLoader( - train_dataset, - sampler=train_sampler, - batch_size=args.train_batch_size, - collate_fn=DefaultDataCollator().collate_batch, + train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator, ) if args.max_steps > 0: @@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""): # Note that DistributedSampler samples randomly eval_sampler = SequentialSampler(eval_dataset) eval_dataloader = DataLoader( - eval_dataset, - sampler=eval_sampler, - batch_size=args.eval_batch_size, - collate_fn=DefaultDataCollator().collate_batch, + eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator, ) # multi-gpu eval diff --git a/examples/bertology/run_bertology.py b/examples/bertology/run_bertology.py index 1d498b8646..990354b894 100644 --- a/examples/bertology/run_bertology.py +++ b/examples/bertology/run_bertology.py @@ -34,8 +34,8 @@ from transformers import ( AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, - DefaultDataCollator, GlueDataset, + 
default_data_collator, glue_compute_metrics, glue_output_modes, glue_processors, @@ -424,7 +424,7 @@ def main(): eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset))))) eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader( - eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch + eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator ) # Compute head entropy and importance score diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b4c21a6b0c..670a7feca7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -364,7 +364,7 @@ if is_torch_available(): # Trainer from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction - from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling + from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments # Benchmarks diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 7cd095651c..629a8b0a6e 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,6 +1,5 @@ -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, NewType, Tuple +from typing import Any, Callable, Dict, List, NewType, Tuple import torch from torch.nn.utils.rnn import pad_sequence @@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence from ..tokenization_utils import PreTrainedTokenizer -class DataCollator(ABC): - """ - A `DataCollator` is responsible for batching - and pre-processing samples of data as requested by the training loop. 
- """ - - @abstractmethod - def collate_batch(self) -> Dict[str, torch.Tensor]: - """ - Take a list of samples from a Dataset and collate them into a batch. - - Returns: - A dictionary of tensors - """ - pass - - InputDataClass = NewType("InputDataClass", Any) +""" +A DataCollator is a function that takes a list of samples from a Dataset +and collates them into a batch, as a dictionary of Tensors. +""" +DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]]) -@dataclass -class DefaultDataCollator(DataCollator): + +def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]: """ Very simple data collator that: - simply collates batches of dict-like objects @@ -42,41 +29,40 @@ class DefaultDataCollator(DataCollator): See glue and ner for example of how it's useful. """ - def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]: - # In this method we'll make the assumption that all `features` in the batch - # have the same attributes. - # So we will look at the first element as a proxy for what attributes exist - # on the whole batch. - first = features[0] + # In this function we'll make the assumption that all `features` in the batch + # have the same attributes. + # So we will look at the first element as a proxy for what attributes exist + # on the whole batch. + first = features[0] - # Special handling for labels. - # Ensure that tensor is created with the correct type - # (it should be automatically the case, but let's make sure of it.)
- if hasattr(first, "label") and first.label is not None: - if type(first.label) is int: - labels = torch.tensor([f.label for f in features], dtype=torch.long) - else: - labels = torch.tensor([f.label for f in features], dtype=torch.float) - batch = {"labels": labels} - elif hasattr(first, "label_ids") and first.label_ids is not None: - if type(first.label_ids[0]) is int: - labels = torch.tensor([f.label_ids for f in features], dtype=torch.long) - else: - labels = torch.tensor([f.label_ids for f in features], dtype=torch.float) - batch = {"labels": labels} + # Special handling for labels. + # Ensure that tensor is created with the correct type + # (it should be automatically the case, but let's make sure of it.) + if hasattr(first, "label") and first.label is not None: + if type(first.label) is int: + labels = torch.tensor([f.label for f in features], dtype=torch.long) else: - batch = {} + labels = torch.tensor([f.label for f in features], dtype=torch.float) + batch = {"labels": labels} + elif hasattr(first, "label_ids") and first.label_ids is not None: + if type(first.label_ids[0]) is int: + labels = torch.tensor([f.label_ids for f in features], dtype=torch.long) + else: + labels = torch.tensor([f.label_ids for f in features], dtype=torch.float) + batch = {"labels": labels} + else: + batch = {} - # Handling of all other possible attributes. - # Again, we will use the first element to figure out which key/values are not None for this model. - for k, v in vars(first).items(): - if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): - batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long) - return batch + # Handling of all other possible attributes. + # Again, we will use the first element to figure out which key/values are not None for this model. 
+ for k, v in vars(first).items(): + if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): + batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long) + return batch @dataclass -class DataCollatorForLanguageModeling(DataCollator): +class DataCollatorForLanguageModeling: """ Data collator used for language modeling. - collates batches of tensors, honoring their tokenizer's pad_token @@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator): mlm: bool = True mlm_probability: float = 0.15 - def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: + def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: batch = self._tensorize_batch(examples) if self.mlm: inputs, labels = self.mask_tokens(batch) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7eeeff6b32..6a4e59b9bf 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler from tqdm.auto import tqdm, trange -from .data.data_collator import DataCollator, DefaultDataCollator +from .data.data_collator import DataCollator, default_data_collator from .modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput @@ -190,10 +190,7 @@ class Trainer: """ self.model = model.to(args.device) self.args = args - if data_collator is not None: - self.data_collator = data_collator - else: - self.data_collator = DefaultDataCollator() + self.data_collator = data_collator if data_collator is not None else default_data_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.compute_metrics = compute_metrics @@ -239,7 +236,7 @@ class Trainer: self.train_dataset, 
batch_size=self.args.train_batch_size, sampler=train_sampler, - collate_fn=self.data_collator.collate_batch, + collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, ) @@ -264,7 +261,7 @@ class Trainer: eval_dataset, sampler=sampler, batch_size=self.args.eval_batch_size, - collate_fn=self.data_collator.collate_batch, + collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, ) @@ -285,7 +282,7 @@ class Trainer: test_dataset, sampler=sampler, batch_size=self.args.eval_batch_size, - collate_fn=self.data_collator.collate_batch, + collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, ) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1717030376..89d6a77918 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -11,7 +11,7 @@ if is_torch_available(): Trainer, LineByLineTextDataset, AutoModelForSequenceClassification, - DefaultDataCollator, + default_data_collator, DataCollatorForLanguageModeling, GlueDataset, GlueDataTrainingArguments, @@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase): task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True ) dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") - data_collator = DefaultDataCollator() - batch = data_collator.collate_batch(dataset.features) + data_collator = default_data_collator + batch = data_collator(dataset.features) self.assertEqual(batch["labels"].dtype, torch.long) def test_default_regression(self): @@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase): task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True ) dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") - data_collator = DefaultDataCollator() - batch = data_collator.collate_batch(dataset.features) + data_collator = default_data_collator + batch = data_collator(dataset.features) self.assertEqual(batch["labels"].dtype, torch.float) def 
test_lm_tokenizer_without_padding(self): @@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase): examples = [dataset[i] for i in range(len(dataset))] with self.assertRaises(ValueError): # Expect error due to padding token missing on gpt2: - data_collator.collate_batch(examples) + data_collator(examples) dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True) examples = [dataset[i] for i in range(len(dataset))] - batch = data_collator.collate_batch(examples) + batch = data_collator(examples) self.assertIsInstance(batch, dict) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) self.assertEqual(batch["labels"].shape, torch.Size((2, 512))) @@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase): dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512) examples = [dataset[i] for i in range(len(dataset))] - batch = data_collator.collate_batch(examples) + batch = data_collator(examples) self.assertIsInstance(batch, dict) self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107))) self.assertEqual(batch["labels"].shape, torch.Size((31, 107))) dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True) examples = [dataset[i] for i in range(len(dataset))] - batch = data_collator.collate_batch(examples) + batch = data_collator(examples) self.assertIsInstance(batch, dict) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) self.assertEqual(batch["labels"].shape, torch.Size((2, 512))) diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py index 7111540190..3836930544 100644 --- a/tests/test_trainer_distributed.py +++ b/tests/test_trainer_distributed.py @@ -29,7 +29,7 @@ if is_torch_available(): from torch import nn from torch.utils.data.dataset import Dataset - from transformers import DataCollator, Trainer + from transformers import Trainer class DummyDataset(Dataset): def 
__init__(self, length: int = 101): @@ -41,8 +41,8 @@ if is_torch_available(): def __getitem__(self, i) -> int: return i - class DummyDataCollator(DataCollator): - def collate_batch(self, features): + class DummyDataCollator: + def __call__(self, features): return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)} class DummyModel(nn.Module):