double check cc @LysandreJik
This commit is contained in:
@@ -75,7 +75,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
|
|||||||
n_batch = len(dataset)
|
n_batch = len(dataset)
|
||||||
input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
|
input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
|
||||||
mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
|
mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
|
||||||
lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
|
lm_labels = np.full((n_batch, 2, input_len), fill_value=-100, dtype=np.int64)
|
||||||
mc_labels = np.zeros((n_batch,), dtype=np.int64)
|
mc_labels = np.zeros((n_batch,), dtype=np.int64)
|
||||||
for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
|
for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
|
||||||
with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
|
with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class Distiller:
|
|||||||
-------
|
-------
|
||||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||||
mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -1 where there is nothing to predict.
|
mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -100 where there is nothing to predict.
|
||||||
"""
|
"""
|
||||||
token_ids, lengths = batch
|
token_ids, lengths = batch
|
||||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||||
@@ -246,7 +246,7 @@ class Distiller:
|
|||||||
-------
|
-------
|
||||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -1 where there is nothing to predict.
|
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -100 where there is nothing to predict.
|
||||||
"""
|
"""
|
||||||
token_ids, lengths = batch
|
token_ids, lengths = batch
|
||||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||||
|
|||||||
Reference in New Issue
Block a user