[s2s] distill: --normalize_hidden --supervise_forward (#6834)
This commit is contained in:
@@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar
|
|||||||
python distillation.py \
|
python distillation.py \
|
||||||
--learning_rate=3e-4 \
|
--learning_rate=3e-4 \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_predict \
|
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--val_check_interval 0.25 \
|
--val_check_interval 0.25 \
|
||||||
--teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
|
--teacher Helsinki-NLP/opus-mt-en-ro \
|
||||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||||
--student_decoder_layers 3 --student_encoder_layers 6 \
|
--student_decoder_layers 3 --student_encoder_layers 6 \
|
||||||
--freeze_encoder --freeze_embeds \
|
--freeze_encoder --freeze_embeds \
|
||||||
@@ -16,6 +15,6 @@ python distillation.py \
|
|||||||
--alpha_hid=3. \
|
--alpha_hid=3. \
|
||||||
--train_batch_size=$BS --eval_batch_size=$BS \
|
--train_batch_size=$BS --eval_batch_size=$BS \
|
||||||
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
|
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
|
||||||
--warmup_steps 500 --sortish_sampler --logger_name wandb \
|
--warmup_steps 500 --logger_name wandb \
|
||||||
--gpus 1 --fp16_opt_level O1 --task translation \
|
--fp16_opt_level O1 --task translation --normalize_hidden \
|
||||||
"$@"
|
"$@"
|
||||||
|
|||||||
@@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
}
|
}
|
||||||
if hparams.length_penalty != -1:
|
if hparams.length_penalty != -1:
|
||||||
student_updates["length_penalty"] = hparams.length_penalty
|
student_updates["length_penalty"] = hparams.length_penalty
|
||||||
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
|
||||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||||
hparams.d_layer_to_copy = d_layers_to_copy
|
|
||||||
hparams.e_layer_to_copy = e_layers_to_copy
|
hparams.e_layer_to_copy = e_layers_to_copy
|
||||||
|
|
||||||
|
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||||
|
|
||||||
|
if hparams.supervise_forward:
|
||||||
|
hparams.d_matches = get_layers_to_supervise(
|
||||||
|
student_updates["decoder_layers"], teacher.config.decoder_layers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hparams.d_matches = d_layers_to_copy
|
||||||
|
hparams.d_layer_to_copy = d_layers_to_copy
|
||||||
|
|
||||||
kw = teacher.config.to_diff_dict()
|
kw = teacher.config.to_diff_dict()
|
||||||
kw.update(student_updates)
|
kw.update(student_updates)
|
||||||
# Copy weights
|
# Copy weights
|
||||||
@@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||||
if self.alpha_hid > 0:
|
if self.alpha_hid > 0:
|
||||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
|
||||||
|
|
||||||
blended_loss = (
|
blended_loss = (
|
||||||
self.alpha_ce * loss_ce
|
self.alpha_ce * loss_ce
|
||||||
@@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
||||||
mask = attention_mask.to(hidden_states[0])
|
mask = attention_mask.to(hidden_states[0])
|
||||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||||
hidden_losses = [
|
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
||||||
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
|
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||||
/ valid_count
|
if self.hparams.normalize_hidden:
|
||||||
for i, j in enumerate(matches)
|
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
||||||
]
|
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||||
return sum(hidden_losses)
|
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
||||||
|
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
|
||||||
|
return masked_mse
|
||||||
|
|
||||||
|
|
||||||
def add_distill_args(parser):
|
def add_distill_args(parser):
|
||||||
@@ -255,6 +266,8 @@ def add_distill_args(parser):
|
|||||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||||
|
parser.add_argument("--supervise_forward", action="store_true", default=False)
|
||||||
|
parser.add_argument("--normalize_hidden", action="store_true", default=False)
|
||||||
|
|
||||||
|
|
||||||
class BartTranslationDistiller(BartSummarizationDistiller):
|
class BartTranslationDistiller(BartSummarizationDistiller):
|
||||||
@@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
|||||||
|
|
||||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||||
if self.alpha_hid > 0:
|
if self.alpha_hid > 0:
|
||||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
|
||||||
|
|
||||||
blended_loss = (
|
blended_loss = (
|
||||||
self.alpha_ce * loss_ce
|
self.alpha_ce * loss_ce
|
||||||
@@ -463,15 +476,28 @@ LAYERS_TO_COPY = {
|
|||||||
},
|
},
|
||||||
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
||||||
}
|
}
|
||||||
|
LAYERS_TO_SUPERVISE = {
|
||||||
|
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
|
||||||
|
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
|
||||||
|
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
|
||||||
|
2: {1: [1], 2: [0, 1]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_layers_to_supervise(n_student, n_teacher):
|
||||||
|
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
|
||||||
|
|
||||||
|
|
||||||
def get_layers_to_copy(n_student, n_teacher):
|
def get_layers_to_copy(n_student, n_teacher):
|
||||||
try:
|
try:
|
||||||
return LAYERS_TO_COPY[n_teacher][n_student]
|
val = LAYERS_TO_COPY[n_teacher][n_student]
|
||||||
|
assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student
|
||||||
|
return val
|
||||||
except KeyError:
|
except KeyError:
|
||||||
warnings.warn(
|
if n_student != n_teacher:
|
||||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
warnings.warn(
|
||||||
)
|
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
||||||
|
)
|
||||||
return list(range(n_student))
|
return list(range(n_student))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
CHEAP_ARGS = {
|
CHEAP_ARGS = {
|
||||||
|
"supervise_forward": True,
|
||||||
|
"normalize_hidden": True,
|
||||||
"label_smoothing": 0.2,
|
"label_smoothing": 0.2,
|
||||||
"eval_beams": 1,
|
"eval_beams": 1,
|
||||||
"val_metric": "loss",
|
"val_metric": "loss",
|
||||||
|
|||||||
Reference in New Issue
Block a user