Make all data collators accept dict (#6065)

* Make all data collators accept dict: each example may be either a tensor or a dict, in which case its `input_ids` entry is used

* Style
This commit is contained in:
Sylvain Gugger
2020-07-28 09:08:20 -04:00
committed by GitHub
parent 31a5486e42
commit 0206efb4cf

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Tuple from typing import Any, Callable, Dict, List, NewType, Tuple, Union
import torch import torch
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
@@ -77,7 +77,9 @@ class DataCollatorForLanguageModeling:
mlm: bool = True mlm: bool = True
mlm_probability: float = 0.15 mlm_probability: float = 0.15
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], dict):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples) batch = self._tensorize_batch(examples)
if self.mlm: if self.mlm:
inputs, labels = self.mask_tokens(batch) inputs, labels = self.mask_tokens(batch)
@@ -148,7 +150,9 @@ class DataCollatorForPermutationLanguageModeling:
plm_probability: float = 1 / 6 plm_probability: float = 1 / 6
max_span_length: int = 5 # maximum length of a span of masked tokens max_span_length: int = 5 # maximum length of a span of masked tokens
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], dict):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples) batch = self._tensorize_batch(examples)
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch) inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels} return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}