Fix DataCollatorForWholeWordMask (#8379)
* Fix DataCollatorForWholeWordMask
* Replace all `_tensorize_batch` calls in data_collator.py with `_collate_batch`
This commit is contained in:
@@ -315,7 +315,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
||||
input_ids = examples
|
||||
examples = [{"input_ids": e} for e in examples]
|
||||
|
||||
batch_input = self._tensorize_batch(input_ids)
|
||||
batch_input = _collate_batch(input_ids, self.tokenizer)
|
||||
|
||||
mask_labels = []
|
||||
for e in examples:
|
||||
@@ -332,7 +332,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
||||
if i in ref_pos:
|
||||
ref_tokens[i] = "##" + ref_tokens[i]
|
||||
mask_labels.append(self._whole_word_mask(ref_tokens))
|
||||
batch_mask = self._tensorize_batch(mask_labels)
|
||||
batch_mask = _collate_batch(mask_labels, self.tokenizer)
|
||||
inputs, labels = self.mask_tokens(batch_input, batch_mask)
|
||||
return {"input_ids": inputs, "labels": labels}
|
||||
|
||||
@@ -511,28 +511,10 @@ class DataCollatorForPermutationLanguageModeling:
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||
examples = [e["input_ids"] for e in examples]
|
||||
batch = self._tensorize_batch(examples)
|
||||
batch = _collate_batch(examples, self.tokenizer)
|
||||
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||
|
||||
def _tensorize_batch(
    self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> torch.Tensor:
    """Combine a batch of token-id sequences into a single tensor.

    Sequences given as lists or tuples are first converted to integer
    (``torch.long``) tensors. If every sequence has the same length along
    dimension 0 they are stacked along a new leading batch dimension;
    otherwise they are right-padded with ``self.tokenizer.pad_token_id``.

    Args:
        examples: Non-empty batch of sequences, either lists/tuples of token
            ids or tensors. NOTE(review): the annotation also lists dicts,
            but this method does not unpack them -- callers are expected to
            extract ``"input_ids"`` beforehand; confirm against callers.

    Returns:
        The stacked (or padded) batch tensor.

    Raises:
        ValueError: If padding is required but the tokenizer has no pad token.
    """
    # Accept both lists of lists and lists of tensors. Token ids must be
    # integers: ``torch.Tensor(e)`` would silently produce float32 values,
    # so build the tensors with an explicit long dtype instead.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)
    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length:
        # No padding needed, so the pad token is never consulted on this path.
        return torch.stack(examples, dim=0)

    if self.tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({self.tokenizer.__class__.__name__}) does not have one."
        )
    return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
||||
|
||||
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
||||
|
||||
Reference in New Issue
Block a user