[T5] allow config.decoder_layers to control decoder size (#7409)
* Working assymmetrical T5 * rename decoder_layers -> num_decoder_layers * Fix docstring * Allow creation of asymmetric t5 students
This commit is contained in:
@@ -116,12 +116,14 @@ def create_student_by_copying_alternating_layers(
|
|||||||
d = teacher_d
|
d = teacher_d
|
||||||
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
|
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
|
||||||
except AttributeError: # T5
|
except AttributeError: # T5
|
||||||
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers
|
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
|
||||||
assert e == d, "T5 Students must be symmetric"
|
if e is None:
|
||||||
init_kwargs["num_layers"] = e
|
e = teacher_e
|
||||||
|
if d is None:
|
||||||
# Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs
|
d = teacher_d
|
||||||
|
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
|
||||||
|
|
||||||
|
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
|
||||||
init_kwargs.update(extra_config_kwargs)
|
init_kwargs.update(extra_config_kwargs)
|
||||||
|
|
||||||
# Copy weights
|
# Copy weights
|
||||||
|
|||||||
@@ -21,10 +21,8 @@ class MakeStudentTester(unittest.TestCase):
|
|||||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
|
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
|
||||||
self.assertEqual(student.config.num_hidden_layers, 1)
|
self.assertEqual(student.config.num_hidden_layers, 1)
|
||||||
|
|
||||||
def test_invalid_t5(self):
|
def test_asymmetric_t5(self):
|
||||||
# T5 students must have the same e==d because there is only one config property
|
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
|
|
||||||
|
|
||||||
def test_same_decoder_small_encoder(self):
|
def test_same_decoder_small_encoder(self):
|
||||||
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
|
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ class T5Config(PretrainedConfig):
|
|||||||
Size of the intermediate feed forward layer in each :obj:`T5Block`.
|
Size of the intermediate feed forward layer in each :obj:`T5Block`.
|
||||||
num_layers (:obj:`int`, `optional`, defaults to 6):
|
num_layers (:obj:`int`, `optional`, defaults to 6):
|
||||||
Number of hidden layers in the Transformer encoder.
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_decoder_layers (:obj:`int`, `optional`):
|
||||||
|
Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not set.
|
||||||
num_heads (:obj:`int`, `optional`, defaults to 8):
|
num_heads (:obj:`int`, `optional`, defaults to 8):
|
||||||
Number of attention heads for each attention layer in
|
Number of attention heads for each attention layer in
|
||||||
the Transformer encoder.
|
the Transformer encoder.
|
||||||
@@ -80,6 +82,7 @@ class T5Config(PretrainedConfig):
|
|||||||
d_kv=64,
|
d_kv=64,
|
||||||
d_ff=2048,
|
d_ff=2048,
|
||||||
num_layers=6,
|
num_layers=6,
|
||||||
|
num_decoder_layers=None,
|
||||||
num_heads=8,
|
num_heads=8,
|
||||||
relative_attention_num_buckets=32,
|
relative_attention_num_buckets=32,
|
||||||
dropout_rate=0.1,
|
dropout_rate=0.1,
|
||||||
@@ -102,6 +105,9 @@ class T5Config(PretrainedConfig):
|
|||||||
self.d_kv = d_kv
|
self.d_kv = d_kv
|
||||||
self.d_ff = d_ff
|
self.d_ff = d_ff
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.num_decoder_layers = (
|
||||||
|
num_decoder_layers if num_decoder_layers is not None else self.num_layers
|
||||||
|
) # default = symmetry
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||||
self.dropout_rate = dropout_rate
|
self.dropout_rate = dropout_rate
|
||||||
|
|||||||
@@ -907,7 +907,7 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
T5_START_DOCSTRING,
|
T5_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class T5Model(T5PreTrainedModel):
|
class T5Model(T5PreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
|
|
||||||
@@ -919,6 +919,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
decoder_config = copy.deepcopy(config)
|
decoder_config = copy.deepcopy(config)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.is_encoder_decoder = False
|
decoder_config.is_encoder_decoder = False
|
||||||
|
decoder_config.num_layers = config.num_decoder_layers
|
||||||
self.decoder = T5Stack(decoder_config, self.shared)
|
self.decoder = T5Stack(decoder_config, self.shared)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
@@ -1077,6 +1078,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
decoder_config = copy.deepcopy(config)
|
decoder_config = copy.deepcopy(config)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.is_encoder_decoder = False
|
decoder_config.is_encoder_decoder = False
|
||||||
|
decoder_config.num_layers = config.num_decoder_layers
|
||||||
self.decoder = T5Stack(decoder_config, self.shared)
|
self.decoder = T5Stack(decoder_config, self.shared)
|
||||||
|
|
||||||
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ class T5ModelTester:
|
|||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
decoder_start_token_id=0,
|
decoder_start_token_id=0,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
decoder_layers=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@@ -83,6 +84,7 @@ class T5ModelTester:
|
|||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.decoder_start_token_id = decoder_start_token_id
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
self.scope = None
|
self.scope = None
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||||
@@ -105,6 +107,7 @@ class T5ModelTester:
|
|||||||
d_ff=self.d_ff,
|
d_ff=self.d_ff,
|
||||||
d_kv=self.hidden_size // self.num_attention_heads,
|
d_kv=self.hidden_size // self.num_attention_heads,
|
||||||
num_layers=self.num_hidden_layers,
|
num_layers=self.num_hidden_layers,
|
||||||
|
num_decoder_layers=self.decoder_layers,
|
||||||
num_heads=self.num_attention_heads,
|
num_heads=self.num_attention_heads,
|
||||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
@@ -623,3 +626,40 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||||||
output = model.generate(**inputs)
|
output = model.generate(**inputs)
|
||||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
self.assertEqual(translation, expected_translation)
|
self.assertEqual(translation, expected_translation)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class TestAsymmetricT5(unittest.TestCase):
|
||||||
|
def build_model_and_check_forward_pass(self, **kwargs):
|
||||||
|
tester = T5ModelTester(self, **kwargs)
|
||||||
|
config, *inputs = tester.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids,
|
||||||
|
attention_mask,
|
||||||
|
decoder_attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = inputs
|
||||||
|
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||||
|
outputs = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
labels=lm_labels,
|
||||||
|
)
|
||||||
|
# outputs = model(*inputs)
|
||||||
|
assert len(outputs) == 4
|
||||||
|
assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size)
|
||||||
|
assert outputs["loss"].size() == ()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def test_small_decoder(self):
|
||||||
|
# num_hidden_layers is passed to T5Config as num_layers
|
||||||
|
model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2)
|
||||||
|
assert len(model.encoder.block) == 2
|
||||||
|
assert len(model.decoder.block) == 1
|
||||||
|
|
||||||
|
def test_defaulting_to_symmetry(self):
|
||||||
|
# num_hidden_layers is passed to T5Config as num_layers
|
||||||
|
model = self.build_model_and_check_forward_pass(num_hidden_layers=2)
|
||||||
|
assert len(model.decoder.block) == len(model.encoder.block) == 2
|
||||||
|
|||||||
Reference in New Issue
Block a user