Make all data collators accept dict (#6065)
* Make all data collators accept dict * Style
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, NewType, Tuple
|
from typing import Any, Callable, Dict, List, NewType, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
@@ -77,7 +77,9 @@ class DataCollatorForLanguageModeling:
|
|||||||
mlm: bool = True
|
mlm: bool = True
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
|
||||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
|
if isinstance(examples[0], dict):
|
||||||
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
inputs, labels = self.mask_tokens(batch)
|
inputs, labels = self.mask_tokens(batch)
|
||||||
@@ -148,7 +150,9 @@ class DataCollatorForPermutationLanguageModeling:
|
|||||||
plm_probability: float = 1 / 6
|
plm_probability: float = 1 / 6
|
||||||
max_span_length: int = 5 # maximum length of a span of masked tokens
|
max_span_length: int = 5 # maximum length of a span of masked tokens
|
||||||
|
|
||||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
|
if isinstance(examples[0], dict):
|
||||||
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
||||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||||
|
|||||||
Reference in New Issue
Block a user