[s2s] add BartTranslationDistiller for distilling mBART (#6363)

This commit is contained in:
Sam Shleifer
2020-08-12 11:41:04 -04:00
committed by GitHub
parent d2370e1bd8
commit f94a52cd79
2 changed files with 98 additions and 26 deletions

View File

@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
# TODO: understand why this breaks
self.assertEqual(nll_loss, model_computed_loss)
def test_distill_mbart(self):
    """Distill the tiny mBART teacher on a translation task and check the outputs.

    Runs the distiller CLI helper with ``task="translation"`` (en_XX -> ro_RO),
    a 2-layer student encoder and 1-layer student decoder, then verifies the
    files written to ``model.output_dir`` and evaluates the saved checkpoint.
    """
    updates = dict(
        student_encoder_layers=2,
        student_decoder_layers=1,
        num_train_epochs=4,
        val_check_interval=0.25,
        alpha_hid=2.0,
        task="translation",
        # Placeholder value; per its own text it is not used on this path.
        model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
        tokenizer_name=MBART_TINY,
        teacher=MBART_TINY,
        src_lang="en_XX",
        tgt_lang="ro_RO",
    )
    model = self._test_distiller_cli(updates, check_contents=False)

    # Exactly one ``*.ckpt`` file is expected directly in the output directory.
    ckpts = list(Path(model.output_dir).glob("*.ckpt"))
    self.assertEqual(1, len(ckpts))

    transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
    all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
    # Use a unittest assertion rather than a bare ``assert`` so the check is
    # not stripped under ``python -O`` and reports both operands on failure.
    self.assertGreater(len(all_files), 2)
    self.assertEqual(len(transformer_ckpts), 2)

    # NOTE(review): ``tempfile.mkdtemp()`` leaves the directory on disk; the
    # caller is responsible for removing it. Consider cleaning up after the run.
    evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing_eps=0.0,
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,