Bug fix for permutation language modelling (#8409)
This commit is contained in:
@@ -579,7 +579,7 @@ class DataCollatorForPermutationLanguageModeling:
|
|||||||
masked_indices.masked_fill_(padding_mask, value=0.0)
|
masked_indices.masked_fill_(padding_mask, value=0.0)
|
||||||
|
|
||||||
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
||||||
non_func_mask = ~(padding_mask & special_tokens_mask)
|
non_func_mask = ~(padding_mask | special_tokens_mask)
|
||||||
|
|
||||||
inputs[masked_indices] = self.tokenizer.mask_token_id
|
inputs[masked_indices] = self.tokenizer.mask_token_id
|
||||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user