From b53bc55ba9bb10d5ee279eab51a2f0acc5af2a6b Mon Sep 17 00:00:00 2001 From: Vishal Burman Date: Mon, 21 Jun 2021 19:06:44 +0530 Subject: [PATCH] Fix for making student ProphetNet for Seq2Seq Distillation (#12130) * make_student.py: fix to make student ProphetNet * reformat --- .../seq2seq-distillation/make_student.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/seq2seq-distillation/make_student.py b/examples/research_projects/seq2seq-distillation/make_student.py index 2ccff5efde..8d70292d0e 100644 --- a/examples/research_projects/seq2seq-distillation/make_student.py +++ b/examples/research_projects/seq2seq-distillation/make_student.py @@ -118,12 +118,18 @@ def create_student_by_copying_alternating_layers( d = teacher_d init_kwargs.update({"encoder_layers": e, "decoder_layers": d}) except AttributeError: # T5 - teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers + if hasattr(teacher.config, "num_encoder_layers"): + teacher_e, teacher_d = teacher.config.num_encoder_layers, teacher.config.num_decoder_layers + else: + teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers if e is None: e = teacher_e if d is None: d = teacher_d - init_kwargs.update({"num_layers": e, "num_decoder_layers": d}) + if hasattr(teacher.config, "num_encoder_layers"): + init_kwargs.update({"num_encoder_layers": e, "num_decoder_layers": d}) + else: + init_kwargs.update({"num_layers": e, "num_decoder_layers": d}) # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs init_kwargs.update(extra_config_kwargs) @@ -150,8 +156,14 @@ def create_student_by_copying_alternating_layers( d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d) try: - copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) - copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) + if hasattr( + teacher, "prophetnet" + ): # For ProphetNet, student.model.encoder.layers is called student.prophetnet.encoder.layers + copy_layers(teacher.prophetnet.encoder.layers, student.prophetnet.encoder.layers, e_layers_to_copy) + copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy) + else: + copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) + copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy) except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy) copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)