Bug fix for permutation language modelling (#8409)

Shashank Gupta
2020-11-09 20:53:26 +05:30
committed by GitHub
parent bf8625e70b
commit 1e2acd0dcf

@@ -579,7 +579,7 @@ class DataCollatorForPermutationLanguageModeling:
 masked_indices.masked_fill_(padding_mask, value=0.0)
 # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
-non_func_mask = ~(padding_mask & special_tokens_mask)
+non_func_mask = ~(padding_mask | special_tokens_mask)
 inputs[masked_indices] = self.tokenizer.mask_token_id
 labels[~masked_indices] = -100 # We only compute loss on masked tokens
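
Why the one-character change matters: by De Morgan's law, `~(a & b)` is true unless a token is both padding and special, while `~(a | b)` is true only when it is neither. The sketch below illustrates this with plain Python lists standing in for the boolean tensors. The six-token layout and the assumption that the two masks do not overlap are made up for illustration and are not taken from the commit.

```python
# Toy per-token flags for one sequence: [CLS] a b [SEP] [PAD] [PAD]
# (hypothetical values; plain lists stand in for the boolean tensors)
special_tokens_mask = [True, False, False, True, False, False]
padding_mask = [False, False, False, False, True, True]

# Old expression, ~(padding & special): False only for a token that is
# BOTH padding and special, so almost everything counts as non-functional.
old_non_func = [not (p and s) for p, s in zip(padding_mask, special_tokens_mask)]

# New expression, ~(padding | special): True only for a token that is
# NEITHER padding nor special.
new_non_func = [not (p or s) for p, s in zip(padding_mask, special_tokens_mask)]

print(old_non_func)
print(new_non_func)
```

With these toy masks the old expression marks every position as non-functional, including `[CLS]`, `[SEP]` and the padding, whereas the new one marks only the two ordinary tokens.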