From 6078b12098337bcb98c0540b07a623223ffdd1c8 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 4 Sep 2020 14:05:56 -0400 Subject: [PATCH] [s2s] distill: --normalize_hidden --supervise_forward (#6834) --- .../seq2seq/distil_marian_enro_teacher.sh | 7 ++- examples/seq2seq/distillation.py | 54 ++++++++++++++----- examples/seq2seq/test_seq2seq_examples.py | 2 + 3 files changed, 45 insertions(+), 18 deletions(-) diff --git a/examples/seq2seq/distil_marian_enro_teacher.sh b/examples/seq2seq/distil_marian_enro_teacher.sh index 575ecf8502..75ef07bc06 100755 --- a/examples/seq2seq/distil_marian_enro_teacher.sh +++ b/examples/seq2seq/distil_marian_enro_teacher.sh @@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar python distillation.py \ --learning_rate=3e-4 \ --do_train \ - --do_predict \ --fp16 \ --val_check_interval 0.25 \ - --teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \ + --teacher Helsinki-NLP/opus-mt-en-ro \ --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ --student_decoder_layers 3 --student_encoder_layers 6 \ --freeze_encoder --freeze_embeds \ @@ -16,6 +15,6 @@ python distillation.py \ --alpha_hid=3. \ --train_batch_size=$BS --eval_batch_size=$BS \ --tokenizer_name Helsinki-NLP/opus-mt-en-ro \ - --warmup_steps 500 --sortish_sampler --logger_name wandb \ - --gpus 1 --fp16_opt_level O1 --task translation \ + --warmup_steps 500 --logger_name wandb \ + --fp16_opt_level O1 --task translation --normalize_hidden \ "$@" diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 9d9bd3e66a..f9ac39cbc2 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule): } if hparams.length_penalty != -1: student_updates["length_penalty"] = hparams.length_penalty - d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers) - hparams.d_layer_to_copy = d_layers_to_copy hparams.e_layer_to_copy = e_layers_to_copy + + d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) + + if hparams.supervise_forward: + hparams.d_matches = get_layers_to_supervise( + student_updates["decoder_layers"], teacher.config.decoder_layers + ) + else: + hparams.d_matches = d_layers_to_copy + hparams.d_layer_to_copy = d_layers_to_copy + kw = teacher.config.to_diff_dict() kw.update(student_updates) # Copy weights @@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule): dec_mask = decoder_input_ids.ne(pad_token_id) loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits) if self.alpha_hid > 0: - hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy) + hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches) blended_loss = ( self.alpha_ce * loss_ce @@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule): assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}" mask = attention_mask.to(hidden_states[0]) valid_count = mask.sum() * hidden_states[0].size(-1) - hidden_losses = [ - (F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum() - / valid_count - for i, j in enumerate(matches) - ] - return sum(hidden_losses) + student_states = torch.stack([hidden_states[i] for i in range(len(matches))]) + teacher_states = torch.stack([hidden_states_T[j] for j in matches]) + if self.hparams.normalize_hidden: + student_states = F.layer_norm(student_states, student_states.shape[1:]) + teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:]) + mse = F.mse_loss(student_states, teacher_states, reduction="none") + masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count + return masked_mse def add_distill_args(parser): @@ -255,6 +266,8 @@ def add_distill_args(parser): parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) parser.add_argument("--no_teacher", action="store_true", default=False) parser.add_argument("--length_penalty", type=float, default=-1) + parser.add_argument("--supervise_forward", action="store_true", default=False) + parser.add_argument("--normalize_hidden", action="store_true", default=False) class BartTranslationDistiller(BartSummarizationDistiller): @@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller): loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits) if self.alpha_hid > 0: - hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy) + hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches) blended_loss = ( self.alpha_ce * loss_ce @@ -463,15 +476,28 @@ LAYERS_TO_COPY = { }, 6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))}, } +LAYERS_TO_SUPERVISE = { + 12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]}, + 16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]}, + 6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]}, + 2: {1: [1], 2: [0, 1]}, +} + + +def get_layers_to_supervise(n_student, n_teacher): + return LAYERS_TO_SUPERVISE[n_teacher][n_student] def get_layers_to_copy(n_student, n_teacher): try: - return LAYERS_TO_COPY[n_teacher][n_student] + val = LAYERS_TO_COPY[n_teacher][n_student] + assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student + return val except KeyError: - warnings.warn( - f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}" - ) + if n_student != n_teacher: + warnings.warn( + f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}" + ) return list(range(n_student)) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 92d74eeefa..5121e165aa 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() CUDA_AVAILABLE = torch.cuda.is_available() CHEAP_ARGS = { + "supervise_forward": True, + "normalize_hidden": True, "label_smoothing": 0.2, "eval_beams": 1, "val_metric": "loss",