Make all data collators accept dict (#6065)

* Make all data collators accept dict

* Style
This commit is contained in:
Sylvain Gugger
2020-07-28 09:08:20 -04:00
committed by GitHub
parent 31a5486e42
commit 0206efb4cf

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Tuple
from typing import Any, Callable, Dict, List, NewType, Tuple, Union
import torch
from torch.nn.utils.rnn import pad_sequence
@@ -77,7 +77,9 @@ class DataCollatorForLanguageModeling:
mlm: bool = True
mlm_probability: float = 0.15
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], dict):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
if self.mlm:
inputs, labels = self.mask_tokens(batch)
@@ -148,7 +150,9 @@ class DataCollatorForPermutationLanguageModeling:
plm_probability: float = 1 / 6
max_span_length: int = 5 # maximum length of a span of masked tokens
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], dict):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}