[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)

* start adding tie encoder to decoder functionality

* finish model tying

* make style

* Apply suggestions from code review

* fix t5 list including cross attention

* apply sams suggestions

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-08-19 14:23:45 +02:00
committed by GitHub
parent ab42d74850
commit fe0b85e77a
7 changed files with 288 additions and 26 deletions

View File

@@ -14,6 +14,8 @@
# limitations under the License.
import copy
import tempfile
import unittest
from transformers import is_torch_available
@@ -130,7 +132,7 @@ class T5ModelTester:
# all items after square
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
def create_and_check_t5_model(
def create_and_check_model(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config)
@@ -156,7 +158,7 @@ class T5ModelTester:
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_t5_with_lm_head(
def create_and_check_with_lm_head(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
@@ -170,7 +172,7 @@ class T5ModelTester:
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_t5_decoder_model_past(
def create_and_check_decoder_model_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).get_decoder().to(torch_device).eval()
@@ -201,7 +203,7 @@ class T5ModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_t5_decoder_model_attention_mask_past(
def create_and_check_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).get_decoder()
@@ -245,7 +247,7 @@ class T5ModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_t5_and_check_t5_generate_with_past_key_value_states(
def create_and_check_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
@@ -257,13 +259,83 @@ class T5ModelTester:
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_t5_model_fp16_forward(
def create_and_check_model_fp16_forward(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_encoder_decoder_shared_weights(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
for model_class in [T5Model, T5ForConditionalGeneration]:
torch.manual_seed(0)
model = model_class(config=config).to(torch_device).eval()
# load state dict copies weights but does not tie them
model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
torch.manual_seed(0)
tied_config = copy.deepcopy(config)
tied_config.tie_encoder_decoder = True
tied_model = model_class(config=tied_config).to(torch_device).eval()
model_result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that models has less parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
# check that outputs after saving and loading are equal
with tempfile.TemporaryDirectory() as tmpdirname:
tied_model.save_pretrained(tmpdirname)
tied_model = model_class.from_pretrained(tmpdirname)
tied_model.to(torch_device)
tied_model.eval()
# check that models has less parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx],
tied_model_result[0][0, :, random_slice_idx],
atol=1e-4,
)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs
@@ -299,30 +371,34 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
def test_t5_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
def test_t5_decoder_model_past(self):
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_t5_decoder_model_past_with_attn_mask(self):
def test_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_t5_generate_with_past_key_value_states(self):
def test_generate_with_past_key_value_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
def test_encoder_decoder_shared_weights(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_t5_model_fp16_forward(self):
def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
@@ -331,8 +407,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def test_export_to_onnx(self):
import tempfile
config_and_inputs = self.model_tester.prepare_config_and_inputs()
model = T5Model(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname: