Support T5 Distillation w/hidden state supervision (#7599)

This commit is contained in:
Sam Shleifer
2020-10-05 21:31:48 -04:00
committed by GitHub
parent 818c294fdd
commit d5d2744aa7
2 changed files with 36 additions and 29 deletions

View File

@@ -86,7 +86,6 @@ CHEAP_ARGS = {
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_encoder_loss": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()