# Patch in git unified-diff format touching two files under examples/seq2seq.
#
# examples/seq2seq/distillation.py
#   * Imports: adds MBartTokenizer (from transformers), TranslationModule
#     (from finetune) and calculate_bleu_score (from utils), in both the
#     relative-import branch and the ImportError fallback branch.
#   * BartSummarizationDistiller.add_model_specific_args: the inline
#     parser.add_argument(...) calls are removed and replaced by a single
#     add_distill_args(parser) call. The commented-out --alpha_cos line is
#     dropped rather than moved.
#   * add_distill_args(parser): new module-level helper registering --teacher,
#     --alpha_ce, --alpha_mlm, --alpha_encoder_loss, --alpha_hid,
#     --student_decoder_layers, --student_encoder_layers, --no_teacher and
#     --length_penalty with the same defaults the removed lines had.
#   * BartTranslationDistiller: new subclass of BartSummarizationDistiller
#     with mode "translation" and val_metric "bleu". Its __init__ asserts the
#     tokenizer is an MBartTokenizer and that hparams.src_lang / tgt_lang are
#     not None, stores both in self.dataset_kwargs, and sets
#     self.decoder_start_token_id from tokenizer.lang_code_to_id[tgt_lang]
#     when model.config.decoder_start_token_id is None.
#     calc_generative_metrics delegates to calculate_bleu_score.
#   * create_module: the enc_only branches are removed; the module class is
#     now chosen by whether "translation" appears in args.task, T5 combined
#     with a translation task is rejected by an assert, and the selected
#     class name is printed.
#
# examples/seq2seq/test_seq2seq_examples.py
#   * Adds test_distill_mbart, which runs the distiller CLI helper with
#     task="translation" and then checks checkpoint counts in output_dir.
#   * _test_distiller_cli: the default key label_smoothing_eps is renamed to
#     label_smoothing.
#
# NOTE(review): the hunks below have lost their original line breaks (each
# physical line holds many diff lines); restore the newlines before trying
# to apply this patch. The payload is reproduced unchanged.
# NOTE(review): in BartTranslationDistiller.__init__ the isinstance(...)
# test inside the `if` repeats the assert made a few lines earlier in the
# same method - confirm whether both are intended.
diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index e5ac03bbd1..5861ed21c9 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -10,20 +10,40 @@ from torch import nn from torch.nn import functional as F from lightning_base import generic_train -from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration +from transformers import ( + AdamW, + BartConfig, + BartForConditionalGeneration, + MBartTokenizer, + T5Config, + T5ForConditionalGeneration, +) try: - from .finetune import SummarizationModule - from .finetune import main as ft_main + from .finetune import SummarizationModule, TranslationModule from .initialization_utils import init_student, copy_layers - from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad - + from .utils import ( + use_task_specific_params, + pickle_load, + freeze_params, + assert_all_frozen, + any_requires_grad, + calculate_bleu_score, + ) + from .finetune import main as ft_main except ImportError: - from finetune import SummarizationModule + from finetune import SummarizationModule, TranslationModule from finetune import main as ft_main from initialization_utils import init_student, copy_layers - from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad + from utils import ( + use_task_specific_params, + pickle_load, + freeze_params, + assert_all_frozen, + any_requires_grad, + calculate_bleu_score, + ) class BartSummarizationDistiller(SummarizationModule): @@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule): @staticmethod def add_model_specific_args(parser, root_dir): SummarizationModule.add_model_specific_args(parser, root_dir) - parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str) - parser.add_argument("--alpha_ce", default=0.8, type=float) - 
parser.add_argument("--alpha_mlm", default=0.2, type=float) - # parser.add_argument("--alpha_cos", default=0.0, type=float) - parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) - parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) - parser.add_argument("--student_decoder_layers", default=12, type=int, required=False) - parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) - parser.add_argument("--no_teacher", action="store_true", default=False) - parser.add_argument("--length_penalty", type=float, default=-1) - + add_distill_args(parser) return parser def _step(self, batch): @@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule): return sum(hidden_losses) +def add_distill_args(parser): + parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str) + parser.add_argument("--alpha_ce", default=0.8, type=float) + parser.add_argument("--alpha_mlm", default=0.2, type=float) + parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) + parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) + parser.add_argument("--student_decoder_layers", default=12, type=int, required=False) + parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) + parser.add_argument("--no_teacher", action="store_true", default=False) + parser.add_argument("--length_penalty", type=float, default=-1) + + +class BartTranslationDistiller(BartSummarizationDistiller): + mode = "translation" + loss_names = ["loss"] + metric_names = ["bleu"] + val_metric = "bleu" + + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + assert isinstance(self.tokenizer, MBartTokenizer) + assert hparams.src_lang is not None + assert hparams.tgt_lang is not None + self.dataset_kwargs["src_lang"] = hparams.src_lang + self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang + if self.model.config.decoder_start_token_id is None and 
isinstance(self.tokenizer, MBartTokenizer): + self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] + + def calc_generative_metrics(self, preds, target) -> dict: + return calculate_bleu_score(preds, target) + + @staticmethod + def add_model_specific_args(parser, root_dir): + TranslationModule.add_model_specific_args(parser, root_dir) + add_distill_args(parser) + return parser + + class T5SummarizationDistiller(BartSummarizationDistiller): def pre_init(self, hparams): raise NotImplementedError("T5 Distillation does not work yet") @@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller): def create_module(args): t5 = "t5" in args.model_name_or_path if args.no_teacher: - assert not args.enc_only - module_cls = SummarizationModule - elif t5: + module_cls = TranslationModule if "translation" in args.task else SummarizationModule + elif t5: # DISTILL T5 WITH TEACHER FOR SUMMARIZATION + assert "translation" not in args.task, "t5 translation distillation not supported" module_cls = T5SummarizationDistiller - elif args.enc_only: - raise ValueError("Deleted that") - else: - module_cls = BartSummarizationDistiller + else: # DISTILL WITH TEACHER + module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller args.setup_cls: str = module_cls.__name__ + print(f"using module {args.setup_cls}") model = module_cls(args) return model diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 70b34d1d8a..4d02190c82 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase): # TODO: understand why this breaks self.assertEqual(nll_loss, model_computed_loss) + def test_distill_mbart(self): + updates = dict( + student_encoder_layers=2, + student_decoder_layers=1, + num_train_epochs=4, + val_check_interval=0.25, + alpha_hid=2.0, + 
task="translation", + model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED", + tokenizer_name=MBART_TINY, + teacher=MBART_TINY, + src_lang="en_XX", + tgt_lang="ro_RO", + ) + model = self._test_distiller_cli(updates, check_contents=False) + + ckpts = list(Path(model.output_dir).glob("*.ckpt")) + self.assertEqual(1, len(ckpts)) + transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) + all_files = list(Path(model.output_dir).glob("best_tfmr/*")) + assert len(all_files) > 2 + self.assertEqual(len(transformer_ckpts), 2) + + evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) + @unittest.skip("T5 distillation is broken at the moment") def test_distill_t5(self): updates = dict( @@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase): def _test_distiller_cli(self, updates, check_contents=True): default_updates = dict( - label_smoothing_eps=0.0, + label_smoothing=0.0, early_stopping_patience=-1, train_batch_size=1, eval_batch_size=2,