[s2s] add BartTranslationDistiller for distilling mBART (#6363)
This commit is contained in:
@@ -10,20 +10,40 @@ from torch import nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from lightning_base import generic_train
|
from lightning_base import generic_train
|
||||||
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
|
from transformers import (
|
||||||
|
AdamW,
|
||||||
|
BartConfig,
|
||||||
|
BartForConditionalGeneration,
|
||||||
|
MBartTokenizer,
|
||||||
|
T5Config,
|
||||||
|
T5ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .finetune import SummarizationModule
|
from .finetune import SummarizationModule, TranslationModule
|
||||||
from .finetune import main as ft_main
|
|
||||||
from .initialization_utils import init_student, copy_layers
|
from .initialization_utils import init_student, copy_layers
|
||||||
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
from .utils import (
|
||||||
|
use_task_specific_params,
|
||||||
|
pickle_load,
|
||||||
|
freeze_params,
|
||||||
|
assert_all_frozen,
|
||||||
|
any_requires_grad,
|
||||||
|
calculate_bleu_score,
|
||||||
|
)
|
||||||
|
from .finetune import main as ft_main
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from finetune import SummarizationModule
|
from finetune import SummarizationModule, TranslationModule
|
||||||
from finetune import main as ft_main
|
from finetune import main as ft_main
|
||||||
from initialization_utils import init_student, copy_layers
|
from initialization_utils import init_student, copy_layers
|
||||||
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
from utils import (
|
||||||
|
use_task_specific_params,
|
||||||
|
pickle_load,
|
||||||
|
freeze_params,
|
||||||
|
assert_all_frozen,
|
||||||
|
any_requires_grad,
|
||||||
|
calculate_bleu_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BartSummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
@@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
add_distill_args(parser)
|
||||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
|
||||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
|
||||||
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
|
||||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
|
||||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
|
||||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
|
||||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
|
||||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
|
||||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def _step(self, batch):
|
def _step(self, batch):
|
||||||
@@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
return sum(hidden_losses)
|
return sum(hidden_losses)
|
||||||
|
|
||||||
|
|
||||||
|
def add_distill_args(parser):
|
||||||
|
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||||
|
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||||
|
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||||
|
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||||
|
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||||
|
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||||
|
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||||
|
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||||
|
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class BartTranslationDistiller(BartSummarizationDistiller):
|
||||||
|
mode = "translation"
|
||||||
|
loss_names = ["loss"]
|
||||||
|
metric_names = ["bleu"]
|
||||||
|
val_metric = "bleu"
|
||||||
|
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
assert isinstance(self.tokenizer, MBartTokenizer)
|
||||||
|
assert hparams.src_lang is not None
|
||||||
|
assert hparams.tgt_lang is not None
|
||||||
|
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||||
|
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||||
|
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||||
|
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||||
|
|
||||||
|
def calc_generative_metrics(self, preds, target) -> dict:
|
||||||
|
return calculate_bleu_score(preds, target)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_model_specific_args(parser, root_dir):
|
||||||
|
TranslationModule.add_model_specific_args(parser, root_dir)
|
||||||
|
add_distill_args(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class T5SummarizationDistiller(BartSummarizationDistiller):
|
class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||||
def pre_init(self, hparams):
|
def pre_init(self, hparams):
|
||||||
raise NotImplementedError("T5 Distillation does not work yet")
|
raise NotImplementedError("T5 Distillation does not work yet")
|
||||||
@@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
|||||||
def create_module(args):
|
def create_module(args):
|
||||||
t5 = "t5" in args.model_name_or_path
|
t5 = "t5" in args.model_name_or_path
|
||||||
if args.no_teacher:
|
if args.no_teacher:
|
||||||
assert not args.enc_only
|
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
||||||
module_cls = SummarizationModule
|
elif t5: # DISTILL T5 WITH TEACHER FOR SUMMARIZATION
|
||||||
elif t5:
|
assert "translation" not in args.task, "t5 translation distillation not supported"
|
||||||
module_cls = T5SummarizationDistiller
|
module_cls = T5SummarizationDistiller
|
||||||
elif args.enc_only:
|
else: # DISTILL WITH TEACHER
|
||||||
raise ValueError("Deleted that")
|
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
|
||||||
else:
|
|
||||||
module_cls = BartSummarizationDistiller
|
|
||||||
args.setup_cls: str = module_cls.__name__
|
args.setup_cls: str = module_cls.__name__
|
||||||
|
print(f"using module {args.setup_cls}")
|
||||||
model = module_cls(args)
|
model = module_cls(args)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
# TODO: understand why this breaks
|
# TODO: understand why this breaks
|
||||||
self.assertEqual(nll_loss, model_computed_loss)
|
self.assertEqual(nll_loss, model_computed_loss)
|
||||||
|
|
||||||
|
def test_distill_mbart(self):
|
||||||
|
updates = dict(
|
||||||
|
student_encoder_layers=2,
|
||||||
|
student_decoder_layers=1,
|
||||||
|
num_train_epochs=4,
|
||||||
|
val_check_interval=0.25,
|
||||||
|
alpha_hid=2.0,
|
||||||
|
task="translation",
|
||||||
|
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||||
|
tokenizer_name=MBART_TINY,
|
||||||
|
teacher=MBART_TINY,
|
||||||
|
src_lang="en_XX",
|
||||||
|
tgt_lang="ro_RO",
|
||||||
|
)
|
||||||
|
model = self._test_distiller_cli(updates, check_contents=False)
|
||||||
|
|
||||||
|
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||||
|
self.assertEqual(1, len(ckpts))
|
||||||
|
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||||
|
all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
|
||||||
|
assert len(all_files) > 2
|
||||||
|
self.assertEqual(len(transformer_ckpts), 2)
|
||||||
|
|
||||||
|
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||||
|
|
||||||
@unittest.skip("T5 distillation is broken at the moment")
|
@unittest.skip("T5 distillation is broken at the moment")
|
||||||
def test_distill_t5(self):
|
def test_distill_t5(self):
|
||||||
updates = dict(
|
updates = dict(
|
||||||
@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
|
|
||||||
def _test_distiller_cli(self, updates, check_contents=True):
|
def _test_distiller_cli(self, updates, check_contents=True):
|
||||||
default_updates = dict(
|
default_updates = dict(
|
||||||
label_smoothing_eps=0.0,
|
label_smoothing=0.0,
|
||||||
early_stopping_patience=-1,
|
early_stopping_patience=-1,
|
||||||
train_batch_size=1,
|
train_batch_size=1,
|
||||||
eval_batch_size=2,
|
eval_batch_size=2,
|
||||||
|
|||||||
Reference in New Issue
Block a user