@@ -5,6 +5,7 @@ import torch
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from ..tokenization_utils import PreTrainedTokenizer
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
|
from ..tokenization_utils_base import BatchEncoding
|
||||||
|
|
||||||
|
|
||||||
InputDataClass = NewType("InputDataClass", Any)
|
InputDataClass = NewType("InputDataClass", Any)
|
||||||
@@ -33,7 +34,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
|||||||
# have the same attributes.
|
# have the same attributes.
|
||||||
# So we will look at the first element as a proxy for what attributes exist
|
# So we will look at the first element as a proxy for what attributes exist
|
||||||
# on the whole batch.
|
# on the whole batch.
|
||||||
if not isinstance(features[0], dict):
|
if not isinstance(features[0], (dict, BatchEncoding)):
|
||||||
features = [vars(f) for f in features]
|
features = [vars(f) for f in features]
|
||||||
|
|
||||||
first = features[0]
|
first = features[0]
|
||||||
@@ -78,7 +79,7 @@ class DataCollatorForLanguageModeling:
|
|||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
|
||||||
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
if isinstance(examples[0], dict):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
@@ -151,7 +152,7 @@ class DataCollatorForPermutationLanguageModeling:
|
|||||||
max_span_length: int = 5 # maximum length of a span of masked tokens
|
max_span_length: int = 5 # maximum length of a span of masked tokens
|
||||||
|
|
||||||
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
if isinstance(examples[0], dict):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user