@@ -434,13 +434,15 @@ class DataCollatorForNextSentencePrediction:
|
|||||||
else:
|
else:
|
||||||
input_ids = self._tensorize_batch(input_ids)
|
input_ids = self._tensorize_batch(input_ids)
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": self._tensorize_batch(attention_masks),
|
"attention_mask": self._tensorize_batch(attention_masks),
|
||||||
"token_type_ids": self._tensorize_batch(segment_ids),
|
"token_type_ids": self._tensorize_batch(segment_ids),
|
||||||
"masked_lm_labels": mlm_labels if self.mlm else None,
|
|
||||||
"next_sentence_label": torch.tensor(nsp_labels),
|
"next_sentence_label": torch.tensor(nsp_labels),
|
||||||
}
|
}
|
||||||
|
if self.mlm:
|
||||||
|
result["masked_lm_labels"] = mlm_labels
|
||||||
|
return result
|
||||||
|
|
||||||
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
|
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
|
||||||
length_of_first = examples[0].size(0)
|
length_of_first = examples[0].size(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user