From 0206efb4cfcffd9d1cf349b098892cc49c9a3efc Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 28 Jul 2020 09:08:20 -0400 Subject: [PATCH] Make all data collators accept dict (#6065) * Make all data collators accept dict * Style --- src/transformers/data/data_collator.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index b4d9f205b9..29d7bf43a2 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict, List, NewType, Tuple +from typing import Any, Callable, Dict, List, NewType, Tuple, Union import torch from torch.nn.utils.rnn import pad_sequence @@ -77,7 +77,9 @@ class DataCollatorForLanguageModeling: mlm: bool = True mlm_probability: float = 0.15 - def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: + def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + if isinstance(examples[0], dict): + examples = [e["input_ids"] for e in examples] batch = self._tensorize_batch(examples) if self.mlm: inputs, labels = self.mask_tokens(batch) @@ -148,7 +150,9 @@ class DataCollatorForPermutationLanguageModeling: plm_probability: float = 1 / 6 max_span_length: int = 5 # maximum length of a span of masked tokens - def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: + def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + if isinstance(examples[0], dict): + examples = [e["input_ids"] for e in examples] batch = self._tensorize_batch(examples) inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch) return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}