Support T5 Distillation w/hidden state supervision (#7599)
This commit is contained in:
@@ -28,7 +28,7 @@ from lightning_base import generic_train # noqa
|
|||||||
class BartSummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
||||||
|
|
||||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
assert Path(hparams.data_dir).exists()
|
assert Path(hparams.data_dir).exists()
|
||||||
@@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
if hparams.length_penalty != -1:
|
if hparams.length_penalty != -1:
|
||||||
student.config.length_penalty = hparams.length_penalty
|
student.config.length_penalty = hparams.length_penalty
|
||||||
super().__init__(hparams, model=student, config=student.config)
|
super().__init__(hparams, model=student, config=student.config)
|
||||||
|
model_type = student.config.model_type
|
||||||
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
||||||
self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
|
|
||||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
if model_type == "t5":
|
||||||
|
teacher_encoder_layers = len(teacher.get_encoder().block)
|
||||||
|
teacher_decoder_layers = len(teacher.get_decoder().block)
|
||||||
|
else:
|
||||||
|
teacher_encoder_layers = teacher.config.encoder_layers
|
||||||
|
teacher_decoder_layers = teacher.config.decoder_layers
|
||||||
|
|
||||||
|
self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
|
||||||
|
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
|
||||||
|
|
||||||
self.teacher = teacher
|
self.teacher = teacher
|
||||||
freeze_params(self.teacher)
|
freeze_params(self.teacher)
|
||||||
|
|
||||||
@@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
del self.teacher.encoder
|
del self.teacher.encoder
|
||||||
# Intermediate supervision: Decide which layers to supervise
|
# Intermediate supervision: Decide which layers to supervise
|
||||||
if hparams.supervise_forward:
|
if hparams.supervise_forward:
|
||||||
self.d_matches = get_layers_to_supervise(
|
self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
|
||||||
n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
|
self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
|
||||||
)
|
else: # student layer should emulate hidden states of the teacher layer it was copied from
|
||||||
else:
|
self.e_matches = self.e_layer_ids
|
||||||
self.d_matches = self.d_layer_ids
|
self.d_matches = self.d_layer_ids
|
||||||
|
|
||||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
self.temperature = 2.0
|
self.temperature = 2.0
|
||||||
self.alpha_mlm = hparams.alpha_mlm
|
self.alpha_mlm = hparams.alpha_mlm
|
||||||
self.alpha_ce = hparams.alpha_ce
|
self.alpha_ce = hparams.alpha_ce
|
||||||
self.alpha_hid = hparams.alpha_hid
|
self.alpha_hid = hparams.alpha_hid
|
||||||
self.alpha_encoder_loss = hparams.alpha_encoder_loss
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
) # TODO(@sshleifer): return_dict=True cleanup
|
)
|
||||||
|
|
||||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||||
@@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
def zero_tensor():
|
def zero_tensor():
|
||||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||||
|
|
||||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
||||||
if self.different_encoder:
|
if self.different_encoder: # compute encoder hidden state loss
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
|
teacher_enc_hid = self.teacher.get_encoder()(
|
||||||
input_ids, attention_mask=src_mask, output_hidden_states=True
|
input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
|
||||||
)
|
).hidden_states
|
||||||
# DEPRECATE THIS
|
|
||||||
if self.hparams.alpha_encoder_loss > 0:
|
|
||||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
|
||||||
|
|
||||||
hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
|
hid_loss_enc = self.calc_hidden_loss(
|
||||||
|
src_mask,
|
||||||
teacher_enc_outputs = (enc_outputs,)
|
enc_hidden_state,
|
||||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
teacher_enc_hid,
|
||||||
|
self.e_matches,
|
||||||
|
normalize_hidden=self.hparams.normalize_hidden,
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
outputs = self.teacher(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=src_mask,
|
attention_mask=src_mask,
|
||||||
encoder_outputs=teacher_enc_outputs,
|
encoder_outputs=(enc_outputs,),
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
lm_labels=labels,
|
lm_labels=labels,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
|
||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||||
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
|
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
|
||||||
@@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
blended_loss = (
|
blended_loss = (
|
||||||
self.alpha_ce * loss_ce
|
self.alpha_ce * loss_ce
|
||||||
+ self.alpha_mlm * student_lm_loss
|
+ self.alpha_mlm * student_lm_loss
|
||||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
|
||||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||||
)
|
)
|
||||||
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
|
return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
|
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
|
||||||
@@ -207,7 +218,6 @@ def add_distill_args(parser):
|
|||||||
parser.add_argument("--teacher", type=str)
|
parser.add_argument("--teacher", type=str)
|
||||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
|
||||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||||
|
|||||||
@@ -86,7 +86,6 @@ CHEAP_ARGS = {
|
|||||||
"n_val": -1,
|
"n_val": -1,
|
||||||
"n_test": -1,
|
"n_test": -1,
|
||||||
"student_encoder_layers": 1,
|
"student_encoder_layers": 1,
|
||||||
"alpha_encoder_loss": 0.0,
|
|
||||||
"freeze_encoder": False,
|
"freeze_encoder": False,
|
||||||
"auto_scale_batch_size": False,
|
"auto_scale_batch_size": False,
|
||||||
}
|
}
|
||||||
@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
|
|
||||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||||
|
|
||||||
@unittest.skip("T5 distillation is broken at the moment")
|
|
||||||
def test_distill_t5(self):
|
def test_distill_t5(self):
|
||||||
updates = dict(
|
updates = dict(
|
||||||
student_encoder_layers=1,
|
student_encoder_layers=1,
|
||||||
@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
model_name_or_path="sshleifer/tinier_bart",
|
model_name_or_path="sshleifer/tinier_bart",
|
||||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||||
val_check_interval=0.5,
|
val_check_interval=0.5,
|
||||||
alpha_encoder_loss=0.4,
|
|
||||||
)
|
)
|
||||||
default_updates.update(updates)
|
default_updates.update(updates)
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
|
|||||||
Reference in New Issue
Block a user