double check cc @LysandreJik
This commit is contained in:
@@ -186,7 +186,7 @@ class Distiller:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -1 where there is nothing to predict.
|
||||
mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
@@ -246,7 +246,7 @@ class Distiller:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -1 where there is nothing to predict.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
|
||||
Reference in New Issue
Block a user