From eab5f59682cf197cd5fd19d499b3670dbef67000 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 28 Sep 2020 00:40:46 +0530 Subject: [PATCH] [s2s] add create student script (#7290) Co-authored-by: Suraj Patil Co-authored-by: Sam Shleifer --- examples/seq2seq/README.md | 4 +- examples/seq2seq/distillation.py | 349 ++++-------------- examples/seq2seq/initialization_utils.py | 20 - examples/seq2seq/make_student.py | 167 +++++++++ .../save_randomly_initialized_model.py | 26 ++ .../seq2seq/test_data/wmt_en_ro/train.len | Bin 34 -> 26 bytes examples/seq2seq/test_data/wmt_en_ro/val.len | Bin 48 -> 40 bytes examples/seq2seq/test_make_student.py | 41 ++ examples/seq2seq/train_distilbart_xsum.sh | 19 +- 9 files changed, 312 insertions(+), 314 deletions(-) delete mode 100644 examples/seq2seq/initialization_utils.py create mode 100644 examples/seq2seq/make_student.py create mode 100755 examples/seq2seq/save_randomly_initialized_model.py create mode 100644 examples/seq2seq/test_make_student.py diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 9eaec29213..2efde25c50 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -369,7 +369,7 @@ runtime: 6H on NVIDIA RTX 24GB GPU If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent, because you will have the same hyperparameters logged in every run. -#### With a teacher +#### With a teacher (Intermediate Supervision) *Note* only BART variants are supported In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`. @@ -378,7 +378,7 @@ This is how `sshleifer/distilbart-xsum*` checkpoints were produced. The command that produced `sshleifer/distilbart-xsum-12-6` is: ```bash -./train_distilbart_xsum.sh +./train_distilbart_xsum.sh --logger_name wandb --gpus 1 ``` runtime: 13H on V-100 16GB GPU. 
diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 2ad6d32a6d..1e8ff40306 100755 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -4,7 +4,6 @@ import argparse import gc import os import sys -import warnings from pathlib import Path from typing import List @@ -15,18 +14,10 @@ from torch.nn import functional as F from finetune import SummarizationModule, TranslationModule from finetune import main as ft_main -from initialization_utils import copy_layers, init_student -from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration +from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise +from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration from transformers.modeling_bart import shift_tokens_right -from utils import ( - any_requires_grad, - assert_all_frozen, - calculate_bleu, - freeze_params, - label_smoothed_nll_loss, - pickle_load, - use_task_specific_params, -) +from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, pickle_load, use_task_specific_params # need the parent dir module @@ -41,87 +32,50 @@ class BartSummarizationDistiller(SummarizationModule): def __init__(self, hparams): assert Path(hparams.data_dir).exists() - student, student_cfg, teacher = self.pre_init(hparams) + self.output_dir = Path(hparams.output_dir) + self.output_dir.mkdir(exist_ok=True) - super().__init__(hparams, model=student, config=student_cfg) + save_dir = self.output_dir.joinpath("student") + + hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student + teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval() + use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default + student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers( + teacher, e=hparams.student_encoder_layers, 
d=hparams.student_decoder_layers, save_path=save_dir + ) + if hparams.length_penalty != -1: + student.config.length_penalty = hparams.length_penalty + super().__init__(hparams, model=student, config=student.config) + self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int] + self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers + self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers self.teacher = teacher - use_task_specific_params(self.teacher, "summarization") freeze_params(self.teacher) - self.sanity_check_gradients() + + if not self.different_encoder: # To save RAM, delete teacher encoder and freeze student encoder. + try: + del self.teacher.model.encoder + except AttributeError: # T5 + del self.teacher.encoder + # Intermediate supervision: Decide which layers to supervise + if hparams.supervise_forward: + self.d_matches = get_layers_to_supervise( + n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers + ) + else: + self.d_matches = self.d_layer_ids self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") self.temperature = 2.0 self.alpha_mlm = hparams.alpha_mlm self.alpha_ce = hparams.alpha_ce self.alpha_hid = hparams.alpha_hid - # self.alpha_cos = hparams.alpha_cos - self.alpha_encoder_loss = self.hparams.alpha_encoder_loss + self.alpha_encoder_loss = hparams.alpha_encoder_loss gc.collect() torch.cuda.empty_cache() - def sanity_check_gradients(self): - assert_all_frozen(self.teacher) - assert_all_frozen(self.model.model.decoder.embed_tokens) - assert_all_frozen(self.model.model.encoder.embed_tokens) - if self.different_encoder: - assert any_requires_grad(self.model.model.encoder) - else: - freeze_params(self.model.model.encoder) - del self.teacher.model.encoder - - def pre_init(self, hparams): - self.output_dir = Path(hparams.output_dir) - self.output_dir.mkdir(exist_ok=True) - teacher = 
AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval() - student_updates = { - "decoder_layers": hparams.student_decoder_layers, - "encoder_layers": hparams.student_encoder_layers, - } - if hparams.length_penalty != -1: - student_updates["length_penalty"] = hparams.length_penalty - e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers) - hparams.e_layer_to_copy = e_layers_to_copy - - d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) - - if hparams.supervise_forward: - hparams.d_matches = get_layers_to_supervise( - student_updates["decoder_layers"], teacher.config.decoder_layers - ) - else: - hparams.d_matches = d_layers_to_copy - hparams.d_layer_to_copy = d_layers_to_copy - - kw = teacher.config.to_diff_dict() - kw.update(student_updates) - # Copy weights - student_cfg = teacher.config_class(**kw) - student = type(teacher)(student_cfg) - student, _ = init_student(student, teacher) - save_dir = self.output_dir.joinpath("student") - self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) - student.save_pretrained(save_dir) - hparams.model_name_or_path = str(save_dir) - return student, student_cfg, teacher - - def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher): - if teacher.config.model_type == "t5": - return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) - self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers - self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers - if self.different_decoder: - copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) - if self.different_encoder: - copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) - - def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, 
hparams, student, teacher): - self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers - self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers - if self.different_decoder: - copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy) - if self.different_encoder: - copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy) - def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor: + """Supervise MSE(teacher.encoder_outputs, student.encoder_outputs).""" + # raise NotImplementedError() if mask is not None: # mask has False at padding_idx sel_mask = mask[:, :, None].expand_as(student_outputs).bool() @@ -133,20 +87,15 @@ class BartSummarizationDistiller(SummarizationModule): return F.mse_loss(s_logits_slct, t_logits_slct) def calc_ce_loss(self, mask, s_logits, t_logits): - if mask is not None: - # mask has False at padding_idx - sel_mask = mask[:, :, None].expand_as(s_logits) - s_logits_slct = torch.masked_select( - s_logits, sel_mask - ) # (bs * seq_length * voc_size) modulo the 1s in mask - t_logits_slct = torch.masked_select( - t_logits, sel_mask - ) # (bs * seq_length * voc_size) modulo the 1s in mask - else: - t_logits_slct = t_logits - s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask - s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask - t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + """Copy pasted from distillbert (transformers/examples/distillation/)""" + + # mask has False at padding_idx + sel_mask = mask[:, :, None].expand_as(s_logits) + vocab_size = s_logits.size(-1) + s_logits_slct = torch.masked_select(s_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask + t_logits_slct = torch.masked_select(t_logits, sel_mask) # (bs * seq_length * 
voc_size) modulo the 1s in mask + s_logits_slct = s_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask + t_logits_slct = t_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask assert t_logits_slct.size() == s_logits_slct.size() loss_ce = ( self.ce_loss_fct( @@ -155,7 +104,7 @@ class BartSummarizationDistiller(SummarizationModule): ) * (self.temperature) ** 2 ) - return loss_ce, s_logits_slct, t_logits_slct + return loss_ce @staticmethod def add_model_specific_args(parser, root_dir): @@ -164,10 +113,14 @@ class BartSummarizationDistiller(SummarizationModule): return parser def _step(self, batch): - # assert is_frozen(self.teacher) + # assert is_frozen(self.teacher) copied_decoder_layers pad_token_id = self.tokenizer.pad_token_id - input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"] - decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) + input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"] + if isinstance(self.model, T5ForConditionalGeneration): + decoder_input_ids = self.model._shift_right(labels) + else: + decoder_input_ids = shift_tokens_right(labels, pad_token_id) + # noinspection PyCallingNonCallable lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self( input_ids, @@ -183,11 +136,11 @@ class BartSummarizationDistiller(SummarizationModule): if self.hparams.label_smoothing == 0: # Same behavior as modeling_bart.py, besides ignoring pad_token_id loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) - student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) + student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1)) else: lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) student_lm_loss, _ = label_smoothed_nll_loss( - lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id + lprobs, labels, 
self.hparams.label_smoothing, ignore_index=pad_token_id ) def zero_tensor(): @@ -196,15 +149,14 @@ class BartSummarizationDistiller(SummarizationModule): loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() if self.different_encoder: with torch.no_grad(): - teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder( + teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()( input_ids, attention_mask=src_mask, output_hidden_states=True ) + # DEPRECATE THIS if self.hparams.alpha_encoder_loss > 0: loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask) - hid_loss_enc = self.calc_hidden_loss( - src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy - ) + hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids) teacher_enc_outputs = (enc_outputs,) assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs) @@ -215,13 +167,15 @@ class BartSummarizationDistiller(SummarizationModule): attention_mask=src_mask, encoder_outputs=teacher_enc_outputs, decoder_input_ids=decoder_input_ids, - lm_labels=tgt_ids, + lm_labels=labels, output_hidden_states=True, ) dec_mask = decoder_input_ids.ne(pad_token_id) - loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits) - if self.alpha_hid > 0: - hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches) + loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits) + if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states + hid_loss_dec = self.calc_hidden_loss( + dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden + ) blended_loss = ( self.alpha_ce * loss_ce @@ -231,7 +185,9 @@ class BartSummarizationDistiller(SummarizationModule): ) return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec - def calc_hidden_loss(self, 
attention_mask, hidden_states, hidden_states_T, matches): + @staticmethod + def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden): + """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT.""" msg = "expected list or tuple for hidden_states, got tensor of shape: " assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}" assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}" @@ -239,7 +195,7 @@ class BartSummarizationDistiller(SummarizationModule): valid_count = mask.sum() * hidden_states[0].size(-1) student_states = torch.stack([hidden_states[i] for i in range(len(matches))]) teacher_states = torch.stack([hidden_states_T[j] for j in matches]) - if self.hparams.normalize_hidden: + if normalize_hidden: student_states = F.layer_norm(student_states, student_states.shape[1:]) teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:]) mse = F.mse_loss(student_states, teacher_states, reduction="none") @@ -287,130 +243,9 @@ class BartTranslationDistiller(BartSummarizationDistiller): return parser -class T5SummarizationDistiller(BartSummarizationDistiller): - def pre_init(self, hparams): - raise NotImplementedError("T5 Distillation does not work yet") - self.output_dir = Path(hparams.output_dir) - self.output_dir.mkdir(exist_ok=True) - teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher) - n_layer = hparams.student_decoder_layers - assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6. 
- d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block)) - e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block)) - student_updates = {"num_layers": n_layer} - hparams.d_layer_to_copy = d_layers_to_copy - hparams.e_layer_to_copy = e_layers_to_copy - kw = teacher.config.to_diff_dict() - - kw.update(student_updates) - # Copy weights - student_cfg = T5Config(**kw) - student = T5ForConditionalGeneration(student_cfg) - student, _ = init_student(student, teacher) - self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) - Path(hparams.output_dir).mkdir(exist_ok=True) - task_specific_params = student.config.task_specific_params - if task_specific_params is not None: - student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode - save_dir = self.output_dir.joinpath("student") - save_dir.mkdir(exist_ok=True) - - student.save_pretrained(save_dir) - hparams.model_name_or_path = str(save_dir) - return student, student_cfg, teacher - - def freeze_embeds(self): - freeze_params(self.model.shared) - for d in [self.model.encoder, self.model.decoder]: - freeze_params(d.embed_tokens) - - def sanity_check_gradients(self): - """T5""" - assert_all_frozen(self.teacher) - assert_all_frozen(self.model.decoder.embed_tokens) - assert_all_frozen(self.model.encoder.embed_tokens) - if self.different_encoder: - assert any_requires_grad(self.model.encoder) - else: - freeze_params(self.model.encoder) - del self.teacher.model.encoder - if self.different_decoder: - assert any_requires_grad(self.model.decoder) - else: - freeze_params(self.model.decoder) # TODO(SS): very suspicious - - def _step(self, batch): - pad_token_id = self.tokenizer.pad_token_id - source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] - decoder_input_ids = y[:, :-1].contiguous() - labels = y[:, 1:].clone() - labels[y[:, 1:] == pad_token_id] = -100 - # noinspection 
PyCallingNonCallable - dec_mask = decoder_input_ids.ne(pad_token_id) - - sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self( - source_ids, - attention_mask=source_mask, - decoder_input_ids=decoder_input_ids, - labels=labels, - output_hidden_states=True, - output_attentions=False, - use_cache=False, - ) - - def zero_tensor(): - return torch.tensor(0.0).type_as(sloss) - - loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() - if self.different_encoder: - with torch.no_grad(): - teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder( - source_ids, - attention_mask=source_mask, - output_hidden_states=True, - use_cache=False, - ) - if self.hparams.alpha_encoder_loss > 0: - loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask) - - hid_loss_enc = self.calc_hidden_loss( - source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy - ) - - teacher_enc_outputs = (enc_outputs,) - assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs) - - with torch.no_grad(): - tloss, tlogits, tdec_hidden, _ = self.teacher( - source_ids, - attention_mask=source_mask, - encoder_outputs=teacher_enc_outputs, - decoder_input_ids=decoder_input_ids, - labels=labels, - output_hidden_states=True, - use_cache=False, - ) - - loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits) - if self.alpha_hid > 0: - hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches) - - blended_loss = ( - self.alpha_ce * loss_ce - + self.alpha_mlm * sloss - + self.hparams.alpha_encoder_loss * loss_encoder - + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec) - ) - return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec - - def create_module(args): - t5 = "t5" in args.model_name_or_path if args.no_teacher: module_cls = TranslationModule if "translation" in args.task else SummarizationModule - elif t5: # DISTILL T5 
WITH TEACHER FOR SUMMARIZATION - assert "translation" not in args.task, "t5 translation distillation not supported" - module_cls = T5SummarizationDistiller else: # DISTILL WITH TEACHER module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller args.setup_cls: str = module_cls.__name__ @@ -443,56 +278,6 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): trainer.test(model) -LAYERS_TO_COPY = { - # maps num layers in student -> which teacher layers to copy. - # 12: bart, 16: pegasus, 6: marian/Helsinki-NLP - 12: { - 1: [0], - 2: [0, 6], - 3: [0, 6, 11], - 4: [0, 4, 8, 11], - 6: [0, 2, 4, 7, 9, 11], - 9: [0, 1, 2, 4, 5, 7, 9, 10, 11], - 12: list(range(12)), - }, - 16: { # maps num layers in student -> which teacher layers to copy - 1: [0], - 2: [0, 8], - 3: [0, 8, 15], - 4: [0, 5, 10, 15], - 6: [0, 3, 6, 9, 12, 15], - 8: [0, 2, 4, 6, 8, 10, 12, 15], - 9: [0, 1, 3, 5, 7, 9, 11, 13, 15], - 12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15], - 16: list(range(16)), - }, - 6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))}, -} -LAYERS_TO_SUPERVISE = { - 12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]}, - 16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]}, - 6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]}, - 2: {1: [1], 2: [0, 1]}, -} - - -def get_layers_to_supervise(n_student, n_teacher): - return LAYERS_TO_SUPERVISE[n_teacher][n_student] - - -def get_layers_to_copy(n_student, n_teacher): - try: - val = LAYERS_TO_COPY[n_teacher][n_student] - assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student - return val - except KeyError: - if n_student != n_teacher: - warnings.warn( - f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}" - ) - return list(range(n_student)) - - def distill_main(args): Path(args.output_dir).mkdir(exist_ok=True) if len(os.listdir(args.output_dir)) > 3 and args.do_train: diff 
--git a/examples/seq2seq/initialization_utils.py b/examples/seq2seq/initialization_utils.py deleted file mode 100644 index 02cba8b352..0000000000 --- a/examples/seq2seq/initialization_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import List - -from torch import nn - - -def init_student(student, teacher): - teacher_state_dict = teacher.state_dict() - info = student.load_state_dict(teacher_state_dict, strict=False) - assert info.missing_keys == [], info.missing_keys - return student, info - - -def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]): - copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy) - - -def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None: - layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy]) - assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}" - student_layers.load_state_dict(layers_to_copy.state_dict()) diff --git a/examples/seq2seq/make_student.py b/examples/seq2seq/make_student.py new file mode 100644 index 0000000000..5150edef3a --- /dev/null +++ b/examples/seq2seq/make_student.py @@ -0,0 +1,167 @@ +import warnings +from pathlib import Path +from typing import List, Tuple, Union + +import fire +from torch import nn + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None: + layers_to_copy = nn.ModuleList([l for i, l in enumerate(src_layers) if i in layers_to_copy]) + assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}" + dest_layers.load_state_dict(layers_to_copy.state_dict()) + + +LAYERS_TO_COPY = { + # maps num layers in teacher -> num_layers in student -> which teacher layers to copy. 
+ # 12: bart, 16: pegasus, 6: marian/Helsinki-NLP + 12: { + 1: [0], # This says that if the teacher has 12 layers and the student has 1, copy layer 0 of the teacher + 2: [0, 6], + 3: [0, 6, 11], + 4: [0, 4, 8, 11], + 6: [0, 2, 4, 7, 9, 11], + 9: [0, 1, 2, 4, 5, 7, 9, 10, 11], + 12: list(range(12)), + }, + 16: { # maps num layers in student -> which teacher layers to copy + 1: [0], + 2: [0, 8], + 3: [0, 8, 15], + 4: [0, 5, 10, 15], + 6: [0, 3, 6, 9, 12, 15], + 8: [0, 2, 4, 6, 8, 10, 12, 15], + 9: [0, 1, 3, 5, 7, 9, 11, 13, 15], + 12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15], + 16: list(range(16)), + }, + 6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))}, +} +LAYERS_TO_SUPERVISE = { + # maps num layers in teacher -> num layers in student -> which teacher layers to supervise. + 6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]}, + 12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]}, + 16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]}, +} + + +def pick_layers_to_copy(n_student, n_teacher): + try: + val = LAYERS_TO_COPY[n_teacher][n_student] + return val + except KeyError: + if n_student != n_teacher: + warnings.warn( + f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}" + ) + return list(range(n_student)) + + +def get_layers_to_supervise(n_student, n_teacher) -> List[int]: + """Used for the --supervise_forward kwarg""" + if n_student > n_teacher: + raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}") + elif n_teacher == n_student: + return list(range(n_teacher)) + elif n_student == 1: + return [n_teacher - 1] + else: + return LAYERS_TO_SUPERVISE[n_teacher][n_student] + + +def create_student_by_copying_alternating_layers( + teacher: Union[str, PreTrainedModel], + save_path: Union[str, Path] = "student", + e: Union[int, None] = None, + d: Union[int, None] = None, + copy_first_teacher_layers=False, + **extra_config_kwargs +) -> 
Tuple[PreTrainedModel, List[int], List[int]]: + """Make a student by copying alternating layers from a teacher, save it to save_path. + Args: + teacher: str or PreTrainedModel; if str, this will call AutoModelForSeq2SeqLM.from_pretrained(teacher) before + copying layers + save_path: where to save the student, defaults to student directory. + e: how many Encoder layers should the student have, default is a full copy of the teacher + d: how many Decoder layers should the student have, default is a full copy of the teacher + copy_first_teacher_layers: [bool] don't copy alternating layers, just the first e/d. + **extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used. + + Returns: + student: new, smaller model. (Also saves it to save_path) + e_layers_to_copy: list of which teacher encoder layers were used + d_layers_to_copy: list of which teacher decoder layers were used + """ + _msg = "encoder_layers and decoder_layers cannot be both None-- you would just have an identical teacher." 
+ assert (e is not None) or (d is not None), _msg + if isinstance(teacher, str): + AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience + teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval() + else: + + assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}" + init_kwargs = teacher.config.to_diff_dict() + + try: + teacher_e, teacher_d = teacher.config.encoder_layers, teacher.config.decoder_layers + if e is None: + e = teacher_e + if d is None: + d = teacher_d + init_kwargs.update({"encoder_layers": e, "decoder_layers": d}) + except AttributeError: # T5 + teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers + assert e == d, "T5 Students must be symmetric" + init_kwargs["num_layers"] = e + + # Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs + + init_kwargs.update(extra_config_kwargs) + + # Copy weights + student_cfg = teacher.config_class(**init_kwargs) + student = AutoModelForSeq2SeqLM.from_config(student_cfg) + # Start by copying the full teacher state dict this will copy the first N teacher layers to the student. + info = student.load_state_dict(teacher.state_dict(), strict=False) + assert info.missing_keys == [], info.missing_keys # every student key should have a teacher keys. + + if copy_first_teacher_layers: # Our copying is done. We just log and save + e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d)) + logger.info( + f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}" + ) + student.save_pretrained(save_path) + return student, e_layers_to_copy, d_layers_to_copy + + # Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer. 
+ e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e) + d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d) + + try: + copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) + copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) + except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block + copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy) + copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy) + logger.info( + f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}" + ) + student.config.init_metadata = dict( + teacher_type=teacher.config.model_type, + copied_encoder_layers=e_layers_to_copy, + copied_decoder_layers=d_layers_to_copy, + ) + student.save_pretrained(save_path) + # Save information about copying for easier reproducibility + + return student, e_layers_to_copy, d_layers_to_copy + + +if __name__ == "__main__": + fire.Fire(create_student_by_copying_alternating_layers) diff --git a/examples/seq2seq/save_randomly_initialized_model.py b/examples/seq2seq/save_randomly_initialized_model.py new file mode 100755 index 0000000000..c4a18afb7e --- /dev/null +++ b/examples/seq2seq/save_randomly_initialized_model.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python + +import fire + +from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer + + +def save_randomly_initialized_version(config_name: str, save_dir: str, **config_kwargs): + """Save a randomly initialized version of a model using a pretrained config. 
+ Args: + config_name: which config to use + save_dir: where to save the resulting model and tokenizer + config_kwargs: Passed to AutoConfig + + Usage:: + save_randomly_initialized_version("facebook/bart-large-cnn", "distilbart_random_cnn_6_3", encoder_layers=6, decoder_layers=3, num_beams=3) + """ + cfg = AutoConfig.from_pretrained(config_name, **config_kwargs) + model = AutoModelForSeq2SeqLM.from_config(cfg) + model.save_pretrained(save_dir) + AutoTokenizer.from_pretrained(config_name).save_pretrained(save_dir) + return model + + +if __name__ == "__main__": + fire.Fire(save_randomly_initialized_version) diff --git a/examples/seq2seq/test_data/wmt_en_ro/train.len b/examples/seq2seq/test_data/wmt_en_ro/train.len index 2632a33e8b8a3a601bdf93f3b8e6783a0c7bba60..33ce003c8ae3139914a389a714812a2ab13aece4 100644 GIT binary patch literal 26 hcmZo*jxA)+@V4-d_nz!s?Op24=3V2R>s_6y2LNfV2s8iy literal 34 lcmZo*nJUfz0kKmwye+)ry(fEDdzX5%dDnR7dRM3F0RV-J2?PKD diff --git a/examples/seq2seq/test_data/wmt_en_ro/val.len b/examples/seq2seq/test_data/wmt_en_ro/val.len index fdf8fa353eb8d41a6d92ccae9df4b2f8aaf970b5..897314a960b28d927b597805693e63f9de71d903 100644 GIT binary patch delta 11 ScmXreU~OQIEo7L;s|)}RKmu?8 delta 19 WcmdNe;B8=;s>%QXu~R0pDgyu?i~~*p diff --git a/examples/seq2seq/test_make_student.py b/examples/seq2seq/test_make_student.py new file mode 100644 index 0000000000..9f33069a80 --- /dev/null +++ b/examples/seq2seq/test_make_student.py @@ -0,0 +1,41 @@ +import tempfile +import unittest + +from make_student import create_student_by_copying_alternating_layers +from transformers import AutoConfig +from transformers.file_utils import cached_property +from transformers.testing_utils import require_torch + + +TINY_BART = "sshleifer/bart-tiny-random" +TINY_T5 = "patrickvonplaten/t5-tiny-random" + + +@require_torch +class MakeStudentTester(unittest.TestCase): + @cached_property + def teacher_config(self): + return AutoConfig.from_pretrained(TINY_BART) + + def test_valid_t5(self): + 
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1) + self.assertEqual(student.config.num_hidden_layers, 1) + + def test_invalid_t5(self): + # T5 students must have the same e==d because there is only one config property + with self.assertRaises(AssertionError): + student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None) + + def test_same_decoder_small_encoder(self): + student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None) + self.assertEqual(student.config.encoder_layers, 1) + self.assertEqual(student.config.decoder_layers, self.teacher_config.encoder_layers) + + def test_small_enc_small_dec(self): + student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=1) + self.assertEqual(student.config.encoder_layers, 1) + self.assertEqual(student.config.decoder_layers, 1) + + def test_raises_assert(self): + with self.assertRaises(AssertionError): + create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=None, d=None) diff --git a/examples/seq2seq/train_distilbart_xsum.sh b/examples/seq2seq/train_distilbart_xsum.sh index 4ae5652940..a4c1fff4be 100755 --- a/examples/seq2seq/train_distilbart_xsum.sh +++ b/examples/seq2seq/train_distilbart_xsum.sh @@ -1,21 +1,20 @@ #!/usr/bin/env bash export PYTHONPATH="../":"${PYTHONPATH}" -export BS=16 -export GAS=2 python distillation.py \ + --teacher facebook/bart-large-xsum --data_dir xsum \ + --student_decoder_layers 6 --student_encoder_layers 12 \ + --freeze_encoder --freeze_embeds \ --learning_rate=3e-4 \ --do_train \ --do_predict \ - --fp16 \ - --val_check_interval 0.1 --n_val 1000 \ - --teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \ + --fp16 --fp16_opt_level=O1 \ + --val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \ --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ - 
--student_decoder_layers 6 --student_encoder_layers 12 \ - --freeze_encoder --freeze_embeds \ --model_name_or_path IGNORED \ - --alpha_hid=3. --length_penalty=0.5 \ - --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \ - --tokenizer_name facebook/bart-large \ + --alpha_hid=3. \ + --train_batch_size=16 --eval_batch_size=16 --gradient_accumulation_steps=2 \ + --sortish_sampler \ + --num_train_epochs=6 \ --warmup_steps 500 \ --output_dir distilbart_xsum_12_6 \ "$@"