Bug fix for permutation language modelling (#8409)
This commit is contained in:
@@ -579,7 +579,7 @@ class DataCollatorForPermutationLanguageModeling:
|
||||
masked_indices.masked_fill_(padding_mask, value=0.0)
|
||||
|
||||
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
||||
non_func_mask = ~(padding_mask & special_tokens_mask)
|
||||
non_func_mask = ~(padding_mask | special_tokens_mask)
|
||||
|
||||
inputs[masked_indices] = self.tokenizer.mask_token_id
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
Reference in New Issue
Block a user