[T5] allow config.decoder_layers to control decoder size (#7409)
* Working assymmetrical T5 * rename decoder_layers -> num_decoder_layers * Fix docstring * Allow creation of asymmetric t5 students
This commit is contained in:
@@ -59,6 +59,7 @@ class T5ModelTester:
|
||||
pad_token_id=0,
|
||||
decoder_start_token_id=0,
|
||||
scope=None,
|
||||
decoder_layers=None,
|
||||
):
|
||||
|
||||
self.parent = parent
|
||||
@@ -83,6 +84,7 @@ class T5ModelTester:
|
||||
self.pad_token_id = pad_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.scope = None
|
||||
self.decoder_layers = decoder_layers
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||
@@ -105,6 +107,7 @@ class T5ModelTester:
|
||||
d_ff=self.d_ff,
|
||||
d_kv=self.hidden_size // self.num_attention_heads,
|
||||
num_layers=self.num_hidden_layers,
|
||||
num_decoder_layers=self.decoder_layers,
|
||||
num_heads=self.num_attention_heads,
|
||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||
dropout_rate=self.dropout_rate,
|
||||
@@ -623,3 +626,40 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
output = model.generate(**inputs)
|
||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
self.assertEqual(translation, expected_translation)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestAsymmetricT5(unittest.TestCase):
|
||||
def build_model_and_check_forward_pass(self, **kwargs):
|
||||
tester = T5ModelTester(self, **kwargs)
|
||||
config, *inputs = tester.prepare_config_and_inputs()
|
||||
(
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
) = inputs
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=lm_labels,
|
||||
)
|
||||
# outputs = model(*inputs)
|
||||
assert len(outputs) == 4
|
||||
assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size)
|
||||
assert outputs["loss"].size() == ()
|
||||
return model
|
||||
|
||||
def test_small_decoder(self):
|
||||
# num_hidden_layers is passed to T5Config as num_layers
|
||||
model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2)
|
||||
assert len(model.encoder.block) == 2
|
||||
assert len(model.decoder.block) == 1
|
||||
|
||||
def test_defaulting_to_symmetry(self):
|
||||
# num_hidden_layers is passed to T5Config as num_layers
|
||||
model = self.build_model_and_check_forward_pass(num_hidden_layers=2)
|
||||
assert len(model.decoder.block) == len(model.encoder.block) == 2
|
||||
|
||||
Reference in New Issue
Block a user