Fixed DataCollatorForLanguageModeling not accepting lists of lists (#6685)
* Fixed DataCollatorForLanguageModeling + PermutationLanguageModeling not accepting lists of lists * Update data_collator.py * black was grumpy
This commit is contained in:
@@ -128,7 +128,9 @@ class DataCollatorForLanguageModeling:
|
|||||||
mlm: bool = True
|
mlm: bool = True
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
|
||||||
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
def __call__(
|
||||||
|
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
@@ -141,7 +143,12 @@ class DataCollatorForLanguageModeling:
|
|||||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
labels[labels == self.tokenizer.pad_token_id] = -100
|
||||||
return {"input_ids": batch, "labels": labels}
|
return {"input_ids": batch, "labels": labels}
|
||||||
|
|
||||||
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
|
def _tensorize_batch(
|
||||||
|
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# In order to accept both lists of lists and lists of Tensors
|
||||||
|
if isinstance(examples[0], (list, tuple)):
|
||||||
|
examples = [torch.Tensor(e) for e in examples]
|
||||||
length_of_first = examples[0].size(0)
|
length_of_first = examples[0].size(0)
|
||||||
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
||||||
if are_tensors_same_length:
|
if are_tensors_same_length:
|
||||||
@@ -202,14 +209,21 @@ class DataCollatorForPermutationLanguageModeling:
|
|||||||
plm_probability: float = 1 / 6
|
plm_probability: float = 1 / 6
|
||||||
max_span_length: int = 5 # maximum length of a span of masked tokens
|
max_span_length: int = 5 # maximum length of a span of masked tokens
|
||||||
|
|
||||||
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
def __call__(
|
||||||
|
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
||||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||||
|
|
||||||
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
|
def _tensorize_batch(
|
||||||
|
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# In order to accept both lists of lists and lists of Tensors
|
||||||
|
if isinstance(examples[0], (list, tuple)):
|
||||||
|
examples = [torch.Tensor(e) for e in examples]
|
||||||
length_of_first = examples[0].size(0)
|
length_of_first = examples[0].size(0)
|
||||||
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
||||||
if are_tensors_same_length:
|
if are_tensors_same_length:
|
||||||
|
|||||||
Reference in New Issue
Block a user