prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)
* broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
@@ -11,6 +12,7 @@ from torch.nn import functional as F
|
||||
|
||||
from lightning_base import generic_train
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
@@ -22,6 +24,7 @@ try:
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
freeze_params,
|
||||
label_smoothed_nll_loss,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
@@ -34,12 +37,15 @@ except ImportError:
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
freeze_params,
|
||||
label_smoothed_nll_loss,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
||||
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
|
||||
def __init__(self, hparams):
|
||||
@@ -160,22 +166,32 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
def _step(self, batch):
|
||||
# assert is_frozen(self.teacher)
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
decoder_input_ids = y[:, :-1].contiguous()
|
||||
labels = y[:, 1:].clone()
|
||||
labels[y[:, 1:] == pad_token_id] = -100
|
||||
input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
||||
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
|
||||
# noinspection PyCallingNonCallable
|
||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
)
|
||||
use_cache=False,
|
||||
) # TODO(@sshleifer): return_dict=True cleanup
|
||||
|
||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
|
||||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(sloss)
|
||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||
|
||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||
if self.different_encoder:
|
||||
@@ -199,29 +215,26 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
attention_mask=src_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
lm_labels=labels,
|
||||
lm_labels=tgt_ids,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * sloss
|
||||
+ self.alpha_mlm * student_lm_loss
|
||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||
|
||||
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
||||
assert not isinstance(
|
||||
hidden_states, torch.Tensor
|
||||
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
|
||||
assert not isinstance(
|
||||
hidden_states_T, torch.Tensor
|
||||
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
|
||||
msg = "expected list or tuple for hidden_states, got tensor of shape: "
|
||||
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
|
||||
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
||||
mask = attention_mask.to(hidden_states[0])
|
||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||
hidden_losses = [
|
||||
@@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
|
||||
|
||||
def add_distill_args(parser):
|
||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||
parser.add_argument("--teacher", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
@@ -245,8 +258,9 @@ def add_distill_args(parser):
|
||||
|
||||
|
||||
class BartTranslationDistiller(BartSummarizationDistiller):
|
||||
"""Supports Mbart, Marian, other models that inherit from Bart."""
|
||||
|
||||
mode = "translation"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["bleu"]
|
||||
val_metric = "bleu"
|
||||
|
||||
@@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
attention_mask=source_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
lm_labels=labels,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
@@ -402,6 +416,7 @@ def create_module(args):
|
||||
|
||||
|
||||
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||
# TODO(SS): DELETE?
|
||||
exp_dir = ckpt_path.parent
|
||||
if dest_dir is None:
|
||||
dest_dir = exp_dir
|
||||
@@ -424,33 +439,40 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
def get_layers_to_copy(n_to_get, tot):
|
||||
all_layers = list(range(tot))
|
||||
if tot == 12: # Alternating for special cases
|
||||
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
||||
1: [0],
|
||||
2: [0, 6],
|
||||
3: [0, 6, 11],
|
||||
4: [0, 4, 8, 11],
|
||||
6: [0, 2, 4, 7, 9, 11],
|
||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||
12: all_layers,
|
||||
}
|
||||
return layers_to_copy[n_to_get]
|
||||
elif tot == 16:
|
||||
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
||||
1: [0],
|
||||
2: [0, 8],
|
||||
3: [0, 8, 15],
|
||||
4: [0, 5, 10, 15],
|
||||
6: [0, 3, 6, 9, 12, 15],
|
||||
8: [0, 2, 4, 6, 8, 10, 12, 15],
|
||||
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
|
||||
16: all_layers,
|
||||
}
|
||||
return layers_to_copy[n_to_get]
|
||||
else:
|
||||
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
|
||||
LAYERS_TO_COPY = {
    # Outer key: number of layers in the teacher.
    # Inner dict maps num layers in student -> which teacher layers to copy.
    # 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
    12: {
        1: [0],
        2: [0, 6],
        3: [0, 6, 11],
        4: [0, 4, 8, 11],
        6: [0, 2, 4, 7, 9, 11],
        9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
        12: list(range(12)),
    },
    16: {
        1: [0],
        2: [0, 8],
        3: [0, 8, 15],
        4: [0, 5, 10, 15],
        6: [0, 3, 6, 9, 12, 15],
        8: [0, 2, 4, 6, 8, 10, 12, 15],
        9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
        16: list(range(16)),
    },
    6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}


def get_layers_to_copy(n_student, n_teacher):
    """Return the teacher layer indices to copy into a student.

    Args:
        n_student: number of layers in the student.
        n_teacher: number of layers in the teacher.

    Returns:
        A new list of teacher layer indices. When ``LAYERS_TO_COPY`` has an
        entry for ``n_teacher`` -> ``n_student`` that entry is used; otherwise
        a warning is emitted and the first ``n_student`` indices are returned.
    """
    # Keep the try body to the lookup alone so only a missing table entry
    # triggers the fallback.
    try:
        hardcoded = LAYERS_TO_COPY[n_teacher][n_student]
    except KeyError:
        warnings.warn(
            f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
        )
        # NOTE(review): the fallback does not check n_student <= n_teacher.
        return list(range(n_student))
    # Return a copy: handing out the table's own list would let a caller that
    # mutates the result silently corrupt every later lookup.
    return list(hardcoded)
|
||||
|
||||
|
||||
def distill_main(args):
|
||||
|
||||
Reference in New Issue
Block a user