Support T5 Distillation w/hidden state supervision (#7599)
This commit is contained in:
@@ -86,7 +86,6 @@ CHEAP_ARGS = {
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"alpha_encoder_loss": 0.0,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
@unittest.skip("T5 distillation is broken at the moment")
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
alpha_encoder_loss=0.4,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
Reference in New Issue
Block a user