From d329c9b05dcf2c6a479a135620aaddbf291e8a05 Mon Sep 17 00:00:00 2001 From: Teven Date: Mon, 24 Aug 2020 15:31:44 +0200 Subject: [PATCH] Fixed DataCollatorForLanguageModeling not accepting lists of lists (#6685) * Fixed DataCollatorForLanguageModeling + PermutationLanguageModeling not accepting lists of lists * Update data_collator.py * black was grumpy --- src/transformers/data/data_collator.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 11b8535096..b14d06d4fb 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -128,7 +128,9 @@ class DataCollatorForLanguageModeling: mlm: bool = True mlm_probability: float = 0.15 - def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + def __call__( + self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] + ) -> Dict[str, torch.Tensor]: if isinstance(examples[0], (dict, BatchEncoding)): examples = [e["input_ids"] for e in examples] batch = self._tensorize_batch(examples) @@ -141,7 +143,12 @@ class DataCollatorForLanguageModeling: labels[labels == self.tokenizer.pad_token_id] = -100 return {"input_ids": batch, "labels": labels} - def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: + def _tensorize_batch( + self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] + ) -> torch.Tensor: + # In order to accept both lists of lists and lists of Tensors + if isinstance(examples[0], (list, tuple)): + examples = [torch.Tensor(e) for e in examples] length_of_first = examples[0].size(0) are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) if are_tensors_same_length: @@ -202,14 +209,21 @@ class DataCollatorForPermutationLanguageModeling: plm_probability: float = 1 / 6 max_span_length: int = 5 # maximum length of a span of masked tokens - def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + def __call__( + self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] + ) -> Dict[str, torch.Tensor]: if isinstance(examples[0], (dict, BatchEncoding)): examples = [e["input_ids"] for e in examples] batch = self._tensorize_batch(examples) inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch) return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels} - def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: + def _tensorize_batch( + self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] + ) -> torch.Tensor: + # In order to accept both lists of lists and lists of Tensors + if isinstance(examples[0], (list, tuple)): + examples = [torch.Tensor(e) for e in examples] length_of_first = examples[0].size(0) are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) if are_tensors_same_length: