From 748425d47d450482f1a99738cda2b8576f0c755d Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 28 Sep 2020 03:08:04 -0400 Subject: [PATCH] [T5] allow config.decoder_layers to control decoder size (#7409) * Working assymmetrical T5 * rename decoder_layers -> num_decoder_layers * Fix docstring * Allow creation of asymmetric t5 students --- examples/seq2seq/make_student.py | 12 ++++---- examples/seq2seq/test_make_student.py | 6 ++-- src/transformers/configuration_t5.py | 6 ++++ src/transformers/modeling_t5.py | 4 ++- tests/test_modeling_t5.py | 40 +++++++++++++++++++++++++++ 5 files changed, 58 insertions(+), 10 deletions(-) diff --git a/examples/seq2seq/make_student.py b/examples/seq2seq/make_student.py index 5150edef3a..6a19ef9aa1 100644 --- a/examples/seq2seq/make_student.py +++ b/examples/seq2seq/make_student.py @@ -116,12 +116,14 @@ def create_student_by_copying_alternating_layers( d = teacher_d init_kwargs.update({"encoder_layers": e, "decoder_layers": d}) except AttributeError: # T5 - teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers - assert e == d, "T5 Students must be symmetric" - init_kwargs["num_layers"] = e - - # Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs + teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers + if e is None: + e = teacher_e + if d is None: + d = teacher_d + init_kwargs.update({"num_layers": e, "num_decoder_layers": d}) + # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs init_kwargs.update(extra_config_kwargs) # Copy weights diff --git a/examples/seq2seq/test_make_student.py b/examples/seq2seq/test_make_student.py index 9f33069a80..0a1688a95c 100644 --- a/examples/seq2seq/test_make_student.py +++ b/examples/seq2seq/test_make_student.py @@ -21,10 +21,8 @@ class MakeStudentTester(unittest.TestCase): student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1) self.assertEqual(student.config.num_hidden_layers, 1) - def test_invalid_t5(self): - # T5 students must have the same e==d because there is only one config property - with self.assertRaises(AssertionError): - student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None) + def test_asymmetric_t5(self): + student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None) def test_same_decoder_small_encoder(self): student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None) diff --git a/src/transformers/configuration_t5.py b/src/transformers/configuration_t5.py index 3457f37632..a7b602c1c1 100644 --- a/src/transformers/configuration_t5.py +++ b/src/transformers/configuration_t5.py @@ -57,6 +57,8 @@ class T5Config(PretrainedConfig): Size of the intermediate feed forward layer in each :obj:`T5Block`. num_layers (:obj:`int`, `optional`, defaults to 6): Number of hidden layers in the Transformer encoder. + num_decoder_layers (:obj:`int`, `optional`): + Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not set. num_heads (:obj:`int`, `optional`, defaults to 8): Number of attention heads for each attention layer in the Transformer encoder. @@ -80,6 +82,7 @@ class T5Config(PretrainedConfig): d_kv=64, d_ff=2048, num_layers=6, + num_decoder_layers=None, num_heads=8, relative_attention_num_buckets=32, dropout_rate=0.1, @@ -102,6 +105,9 @@ class T5Config(PretrainedConfig): self.d_kv = d_kv self.d_ff = d_ff self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry self.num_heads = num_heads self.relative_attention_num_buckets = relative_attention_num_buckets self.dropout_rate = dropout_rate diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 71b49b3419..a3b70c492c 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -907,7 +907,7 @@ T5_INPUTS_DOCSTRING = r""" T5_START_DOCSTRING, ) class T5Model(T5PreTrainedModel): - def __init__(self, config): + def __init__(self, config: T5Config): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.d_model) @@ -919,6 +919,7 @@ class T5Model(T5PreTrainedModel): decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers self.decoder = T5Stack(decoder_config, self.shared) self.init_weights() @@ -1077,6 +1078,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers self.decoder = T5Stack(decoder_config, self.shared) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index c5e3ec9d1b..0edb54016d 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -59,6 +59,7 @@ class T5ModelTester: pad_token_id=0, decoder_start_token_id=0, scope=None, + decoder_layers=None, ): self.parent = parent @@ -83,6 +84,7 @@ class T5ModelTester: self.pad_token_id = pad_token_id self.decoder_start_token_id = decoder_start_token_id self.scope = None + self.decoder_layers = decoder_layers def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) @@ -105,6 +107,7 @@ class T5ModelTester: d_ff=self.d_ff, d_kv=self.hidden_size // self.num_attention_heads, num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, num_heads=self.num_attention_heads, relative_attention_num_buckets=self.relative_attention_num_buckets, dropout_rate=self.dropout_rate, @@ -623,3 +626,40 @@ class T5ModelIntegrationTests(unittest.TestCase): output = model.generate(**inputs) translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) self.assertEqual(translation, expected_translation) + + +@require_torch +class TestAsymmetricT5(unittest.TestCase): + def build_model_and_check_forward_pass(self, **kwargs): + tester = T5ModelTester(self, **kwargs) + config, *inputs = tester.prepare_config_and_inputs() + ( + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = inputs + model = T5ForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + # outputs = model(*inputs) + assert len(outputs) == 4 + assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size) + assert outputs["loss"].size() == () + return model + + def test_small_decoder(self): + # num_hidden_layers is passed to T5Config as num_layers + model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2) + assert len(model.encoder.block) == 2 + assert len(model.decoder.block) == 1 + + def test_defaulting_to_symmetry(self): + # num_hidden_layers is passed to T5Config as num_layers + model = self.build_model_and_check_forward_pass(num_hidden_layers=2) + assert len(model.decoder.block) == len(model.encoder.block) == 2