[s2s] add BartTranslationDistiller for distilling mBART (#6363)
This commit is contained in:
@@ -10,20 +10,40 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lightning_base import generic_train
|
||||
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
|
||||
from transformers import (
|
||||
AdamW,
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
MBartTokenizer,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule
|
||||
from .finetune import main as ft_main
|
||||
from .finetune import SummarizationModule, TranslationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
|
||||
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
@@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||
|
||||
add_distill_args(parser)
|
||||
return parser
|
||||
|
||||
def _step(self, batch):
|
||||
@@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
return sum(hidden_losses)
|
||||
|
||||
|
||||
def add_distill_args(parser):
|
||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||
|
||||
|
||||
class BartTranslationDistiller(BartSummarizationDistiller):
|
||||
mode = "translation"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["bleu"]
|
||||
val_metric = "bleu"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
assert isinstance(self.tokenizer, MBartTokenizer)
|
||||
assert hparams.src_lang is not None
|
||||
assert hparams.tgt_lang is not None
|
||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
TranslationModule.add_model_specific_args(parser, root_dir)
|
||||
add_distill_args(parser)
|
||||
return parser
|
||||
|
||||
|
||||
class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
def pre_init(self, hparams):
|
||||
raise NotImplementedError("T5 Distillation does not work yet")
|
||||
@@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
def create_module(args):
|
||||
t5 = "t5" in args.model_name_or_path
|
||||
if args.no_teacher:
|
||||
assert not args.enc_only
|
||||
module_cls = SummarizationModule
|
||||
elif t5:
|
||||
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
||||
elif t5: # DISTILL T5 WITH TEACHER FOR SUMMARIZATION
|
||||
assert "translation" not in args.task, "t5 translation distillation not supported"
|
||||
module_cls = T5SummarizationDistiller
|
||||
elif args.enc_only:
|
||||
raise ValueError("Deleted that")
|
||||
else:
|
||||
module_cls = BartSummarizationDistiller
|
||||
else: # DISTILL WITH TEACHER
|
||||
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
|
||||
args.setup_cls: str = module_cls.__name__
|
||||
print(f"using module {args.setup_cls}")
|
||||
model = module_cls(args)
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user