Fix for making student ProphetNet for Seq2Seq Distillation (#12130)
* make_student.py: fix to make student ProphetNet * reformat
This commit is contained in:
@@ -118,12 +118,18 @@ def create_student_by_copying_alternating_layers(
|
|||||||
d = teacher_d
|
d = teacher_d
|
||||||
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
|
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
|
||||||
except AttributeError: # T5
|
except AttributeError: # T5
|
||||||
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
|
if hasattr(teacher.config, "num_encoder_layers"):
|
||||||
|
teacher_e, teacher_d = teacher.config.num_encoder_layers, teacher.config.num_decoder_layers
|
||||||
|
else:
|
||||||
|
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
|
||||||
if e is None:
|
if e is None:
|
||||||
e = teacher_e
|
e = teacher_e
|
||||||
if d is None:
|
if d is None:
|
||||||
d = teacher_d
|
d = teacher_d
|
||||||
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
|
if hasattr(teacher.config, "num_encoder_layers"):
|
||||||
|
init_kwargs.update({"num_encoder_layers": e, "num_decoder_layers": d})
|
||||||
|
else:
|
||||||
|
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
|
||||||
|
|
||||||
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
|
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
|
||||||
init_kwargs.update(extra_config_kwargs)
|
init_kwargs.update(extra_config_kwargs)
|
||||||
@@ -150,8 +156,14 @@ def create_student_by_copying_alternating_layers(
|
|||||||
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
|
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
if hasattr(
|
||||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
teacher, "prophetnet"
|
||||||
|
): # For ProphetNet, student.model.encoder.layers is called student.prophetnet.encoder.layers
|
||||||
|
copy_layers(teacher.prophetnet.encoder.layers, student.prophetnet.encoder.layers, e_layers_to_copy)
|
||||||
|
copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy)
|
||||||
|
else:
|
||||||
|
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
||||||
|
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
||||||
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
|
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
|
||||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||||
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
||||||
|
|||||||
Reference in New Issue
Block a user