[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)

* start adding tie encoder to decoder functionality

* finish model tying

* make style

* Apply suggestions from code review

* fix t5 list including cross attention

* apply sams suggestions

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-08-19 14:23:45 +02:00
committed by GitHub
parent ab42d74850
commit fe0b85e77a
7 changed files with 288 additions and 26 deletions

View File

@@ -268,6 +268,88 @@ class EncoderDecoderMixin:
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def create_and_check_encoder_decoder_shared_weights(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels,
**kwargs
):
torch.manual_seed(0)
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.to(torch_device)
model.eval()
# load state dict copies weights but does not tie them
decoder_state_dict = model.decoder._modules[model.decoder.base_model_prefix].state_dict()
model.encoder.load_state_dict(decoder_state_dict, strict=False)
torch.manual_seed(0)
tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(config, decoder_config)
config = EncoderDecoderConfig.from_encoder_decoder_configs(
tied_encoder_model.config, tied_decoder_model.config, tie_encoder_decoder=True
)
tied_model = EncoderDecoderModel(encoder=tied_encoder_model, decoder=tied_decoder_model, config=config)
tied_model.to(torch_device)
tied_model.eval()
model_result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that models has less parameters
self.assertLess(sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()))
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
# check that outputs are equal
self.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
# check that outputs after saving and loading are equal
with tempfile.TemporaryDirectory() as tmpdirname:
tied_model.save_pretrained(tmpdirname)
tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
tied_model.to(torch_device)
tied_model.eval()
# check that models has less parameters
self.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that outputs are equal
self.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model(**input_ids_dict)
@@ -296,6 +378,10 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
def test_encoder_decoder_model_shared_weights(self):
input_ids_dict = self.prepare_config_and_inputs()
self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
@@ -480,3 +566,6 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
def test_encoder_decoder_model_shared_weights(self):
pass