Make all data collators accept dict (#6065)
* Make all data collators accept dict * Style
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, NewType, Tuple
|
||||
from typing import Any, Callable, Dict, List, NewType, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
@@ -77,7 +77,9 @@ class DataCollatorForLanguageModeling:
|
||||
mlm: bool = True
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(examples[0], dict):
|
||||
examples = [e["input_ids"] for e in examples]
|
||||
batch = self._tensorize_batch(examples)
|
||||
if self.mlm:
|
||||
inputs, labels = self.mask_tokens(batch)
|
||||
@@ -148,7 +150,9 @@ class DataCollatorForPermutationLanguageModeling:
|
||||
plm_probability: float = 1 / 6
|
||||
max_span_length: int = 5 # maximum length of a span of masked tokens
|
||||
|
||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(examples[0], dict):
|
||||
examples = [e["input_ids"] for e in examples]
|
||||
batch = self._tensorize_batch(examples)
|
||||
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||
|
||||
Reference in New Issue
Block a user