s2s distillation uses AutoModelForSeqToSeqLM (#6761)

This commit is contained in:
Sam Shleifer
2020-08-26 23:25:11 -04:00
committed by GitHub
parent 05e7150a53
commit 4bd7be9a42
2 changed files with 6 additions and 6 deletions

View File

@@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase):
tgt_lang="ro_RO",
)
model = self._test_distiller_cli(updates, check_contents=False)
assert model.model.config.model_type == "mbart"
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))