[s2s] add BartTranslationDistiller for distilling mBART (#6363)
This commit is contained in:
@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
# TODO: understand why this breaks
|
||||
self.assertEqual(nll_loss, model_computed_loss)
|
||||
|
||||
def test_distill_mbart(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
task="translation",
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
tokenizer_name=MBART_TINY,
|
||||
teacher=MBART_TINY,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
|
||||
assert len(all_files) > 2
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
@unittest.skip("T5 distillation is broken at the moment")
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
|
||||
def _test_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing_eps=0.0,
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
|
||||
Reference in New Issue
Block a user