s2s distillation uses AutoModelForSeqToSeqLM (#6761)
This commit is contained in:
@@ -10,7 +10,7 @@ from torch import nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from lightning_base import generic_train
|
from lightning_base import generic_train
|
||||||
from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -74,22 +74,22 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
def pre_init(self, hparams):
|
def pre_init(self, hparams):
|
||||||
self.output_dir = Path(hparams.output_dir)
|
self.output_dir = Path(hparams.output_dir)
|
||||||
self.output_dir.mkdir(exist_ok=True)
|
self.output_dir.mkdir(exist_ok=True)
|
||||||
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
|
||||||
student_updates = {
|
student_updates = {
|
||||||
"decoder_layers": hparams.student_decoder_layers,
|
"decoder_layers": hparams.student_decoder_layers,
|
||||||
"encoder_layers": hparams.student_encoder_layers,
|
"encoder_layers": hparams.student_encoder_layers,
|
||||||
}
|
}
|
||||||
if hparams.length_penalty != -1:
|
if hparams.length_penalty != -1:
|
||||||
student_updates["length_penalty"] = hparams.length_penalty
|
student_updates["length_penalty"] = hparams.length_penalty
|
||||||
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||||
hparams.d_layer_to_copy = d_layers_to_copy
|
hparams.d_layer_to_copy = d_layers_to_copy
|
||||||
hparams.e_layer_to_copy = e_layers_to_copy
|
hparams.e_layer_to_copy = e_layers_to_copy
|
||||||
kw = teacher.config.to_diff_dict()
|
kw = teacher.config.to_diff_dict()
|
||||||
kw.update(student_updates)
|
kw.update(student_updates)
|
||||||
# Copy weights
|
# Copy weights
|
||||||
student_cfg = BartConfig(**kw)
|
student_cfg = teacher.config_class(**kw)
|
||||||
student = BartForConditionalGeneration(student_cfg)
|
student = type(teacher)(student_cfg)
|
||||||
student, _ = init_student(student, teacher)
|
student, _ = init_student(student, teacher)
|
||||||
save_dir = self.output_dir.joinpath("student")
|
save_dir = self.output_dir.joinpath("student")
|
||||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||||
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
|
|||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
assert isinstance(self.tokenizer, MBartTokenizer)
|
|
||||||
assert hparams.src_lang is not None
|
assert hparams.src_lang is not None
|
||||||
assert hparams.tgt_lang is not None
|
assert hparams.tgt_lang is not None
|
||||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||||
|
|||||||
@@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
tgt_lang="ro_RO",
|
tgt_lang="ro_RO",
|
||||||
)
|
)
|
||||||
model = self._test_distiller_cli(updates, check_contents=False)
|
model = self._test_distiller_cli(updates, check_contents=False)
|
||||||
|
assert model.model.config.model_type == "mbart"
|
||||||
|
|
||||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||||
self.assertEqual(1, len(ckpts))
|
self.assertEqual(1, len(ckpts))
|
||||||
|
|||||||
Reference in New Issue
Block a user