[s2s] distill: --normalize_hidden --supervise_forward (#6834)
This commit is contained in:
@@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
}
|
||||
if hparams.length_penalty != -1:
|
||||
student_updates["length_penalty"] = hparams.length_penalty
|
||||
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
hparams.e_layer_to_copy = e_layers_to_copy
|
||||
|
||||
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||
|
||||
if hparams.supervise_forward:
|
||||
hparams.d_matches = get_layers_to_supervise(
|
||||
student_updates["decoder_layers"], teacher.config.decoder_layers
|
||||
)
|
||||
else:
|
||||
hparams.d_matches = d_layers_to_copy
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
|
||||
kw = teacher.config.to_diff_dict()
|
||||
kw.update(student_updates)
|
||||
# Copy weights
|
||||
@@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
@@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
||||
mask = attention_mask.to(hidden_states[0])
|
||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||
hidden_losses = [
|
||||
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
|
||||
/ valid_count
|
||||
for i, j in enumerate(matches)
|
||||
]
|
||||
return sum(hidden_losses)
|
||||
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
||||
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||
if self.hparams.normalize_hidden:
|
||||
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
||||
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
||||
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
|
||||
return masked_mse
|
||||
|
||||
|
||||
def add_distill_args(parser):
|
||||
@@ -255,6 +266,8 @@ def add_distill_args(parser):
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||
parser.add_argument("--supervise_forward", action="store_true", default=False)
|
||||
parser.add_argument("--normalize_hidden", action="store_true", default=False)
|
||||
|
||||
|
||||
class BartTranslationDistiller(BartSummarizationDistiller):
|
||||
@@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
@@ -463,15 +476,28 @@ LAYERS_TO_COPY = {
|
||||
},
|
||||
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
||||
}
|
||||
LAYERS_TO_SUPERVISE = {
|
||||
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
|
||||
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
|
||||
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
|
||||
2: {1: [1], 2: [0, 1]},
|
||||
}
|
||||
|
||||
|
||||
def get_layers_to_supervise(n_student, n_teacher):
|
||||
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
|
||||
|
||||
|
||||
def get_layers_to_copy(n_student, n_teacher):
|
||||
try:
|
||||
return LAYERS_TO_COPY[n_teacher][n_student]
|
||||
val = LAYERS_TO_COPY[n_teacher][n_student]
|
||||
assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student
|
||||
return val
|
||||
except KeyError:
|
||||
warnings.warn(
|
||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
||||
)
|
||||
if n_student != n_teacher:
|
||||
warnings.warn(
|
||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
||||
)
|
||||
return list(range(n_student))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user