[T5] allow config.decoder_layers to control decoder size (#7409)
* Working assymmetrical T5 * rename decoder_layers -> num_decoder_layers * Fix docstring * Allow creation of asymmetric t5 students
This commit is contained in:
@@ -116,12 +116,14 @@ def create_student_by_copying_alternating_layers(
|
||||
d = teacher_d
|
||||
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
|
||||
except AttributeError: # T5
|
||||
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers
|
||||
assert e == d, "T5 Students must be symmetric"
|
||||
init_kwargs["num_layers"] = e
|
||||
|
||||
# Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs
|
||||
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
|
||||
if e is None:
|
||||
e = teacher_e
|
||||
if d is None:
|
||||
d = teacher_d
|
||||
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
|
||||
|
||||
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
|
||||
init_kwargs.update(extra_config_kwargs)
|
||||
|
||||
# Copy weights
|
||||
|
||||
@@ -21,10 +21,8 @@ class MakeStudentTester(unittest.TestCase):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
|
||||
self.assertEqual(student.config.num_hidden_layers, 1)
|
||||
|
||||
def test_invalid_t5(self):
|
||||
# T5 students must have the same e==d because there is only one config property
|
||||
with self.assertRaises(AssertionError):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
|
||||
def test_asymmetric_t5(self):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
|
||||
|
||||
def test_same_decoder_small_encoder(self):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
|
||||
|
||||
Reference in New Issue
Block a user