Fix doc errors and typos across the board (#8139)
* Fix doc errors and typos across the board * Fix a typo * Fix the CI * Fix more typos * Fix CI * More fixes * Fix CI * More fixes * More fixes
This commit is contained in:
@@ -265,7 +265,7 @@ class Distiller:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -100 where there is nothing to predict.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
@@ -401,9 +401,9 @@ class Distiller:
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
|
||||
if self.params.restrict_ce_to_mask:
|
||||
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
|
||||
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
|
||||
@@ -61,7 +61,7 @@ class LmSeqsDataset(Dataset):
|
||||
|
||||
def remove_long_sequences(self):
|
||||
"""
|
||||
Sequences that are too long are splitted by chunk of max_model_input_size.
|
||||
Sequences that are too long are split by chunk of max_model_input_size.
|
||||
"""
|
||||
max_len = self.params.max_model_input_size
|
||||
indices = self.lengths > max_len
|
||||
@@ -138,8 +138,8 @@ class LmSeqsDataset(Dataset):
|
||||
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
|
||||
|
||||
# unk_idx = self.params.special_tok_ids['unk_token']
|
||||
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
||||
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
|
||||
# nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
||||
# logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
|
||||
|
||||
def batch_sequences(self, batch):
|
||||
"""
|
||||
|
||||
@@ -96,7 +96,7 @@ if __name__ == "__main__":
|
||||
compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
|
||||
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
|
||||
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
|
||||
Reference in New Issue
Block a user