[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)
* start adding tie encoder to decoder functionality * finish model tying * make style * Apply suggestions from code review * fix t5 list including cross attention * apply sams suggestions * Update src/transformers/modeling_encoder_decoder.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add max depth break point Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ab42d74850
commit
fe0b85e77a
@@ -268,6 +268,88 @@ class EncoderDecoderMixin:
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def create_and_check_encoder_decoder_shared_weights(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels,
|
||||
**kwargs
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
# load state dict copies weights but does not tie them
|
||||
decoder_state_dict = model.decoder._modules[model.decoder.base_model_prefix].state_dict()
|
||||
model.encoder.load_state_dict(decoder_state_dict, strict=False)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(
|
||||
tied_encoder_model.config, tied_decoder_model.config, tie_encoder_decoder=True
|
||||
)
|
||||
tied_model = EncoderDecoderModel(encoder=tied_encoder_model, decoder=tied_decoder_model, config=config)
|
||||
tied_model.to(torch_device)
|
||||
tied_model.eval()
|
||||
|
||||
model_result = model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
tied_model_result = tied_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# check that models has less parameters
|
||||
self.assertLess(sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()))
|
||||
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
|
||||
|
||||
# check that outputs are equal
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
# check that outputs after saving and loading are equal
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tied_model.save_pretrained(tmpdirname)
|
||||
tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
|
||||
tied_model.to(torch_device)
|
||||
tied_model.eval()
|
||||
|
||||
# check that models has less parameters
|
||||
self.assertLess(
|
||||
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
|
||||
)
|
||||
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
|
||||
|
||||
tied_model_result = tied_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# check that outputs are equal
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
def test_encoder_decoder_model(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model(**input_ids_dict)
|
||||
@@ -296,6 +378,10 @@ class EncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_shared_weights(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
@@ -480,3 +566,6 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
|
||||
|
||||
def test_encoder_decoder_model_shared_weights(self):
|
||||
pass
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
@@ -130,7 +132,7 @@ class T5ModelTester:
|
||||
# all items after square
|
||||
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
|
||||
|
||||
def create_and_check_t5_model(
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config)
|
||||
@@ -156,7 +158,7 @@ class T5ModelTester:
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
|
||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
||||
|
||||
def create_and_check_t5_with_lm_head(
|
||||
def create_and_check_with_lm_head(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
@@ -170,7 +172,7 @@ class T5ModelTester:
|
||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def create_and_check_t5_decoder_model_past(
|
||||
def create_and_check_decoder_model_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).get_decoder().to(torch_device).eval()
|
||||
@@ -201,7 +203,7 @@ class T5ModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_t5_decoder_model_attention_mask_past(
|
||||
def create_and_check_decoder_model_attention_mask_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).get_decoder()
|
||||
@@ -245,7 +247,7 @@ class T5ModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_t5_and_check_t5_generate_with_past_key_value_states(
|
||||
def create_and_check_generate_with_past_key_value_states(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
@@ -257,13 +259,83 @@ class T5ModelTester:
|
||||
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
||||
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||
|
||||
def create_and_check_t5_model_fp16_forward(
|
||||
def create_and_check_model_fp16_forward(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).to(torch_device).half().eval()
|
||||
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_encoder_decoder_shared_weights(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
for model_class in [T5Model, T5ForConditionalGeneration]:
|
||||
torch.manual_seed(0)
|
||||
model = model_class(config=config).to(torch_device).eval()
|
||||
# load state dict copies weights but does not tie them
|
||||
model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tied_config = copy.deepcopy(config)
|
||||
tied_config.tie_encoder_decoder = True
|
||||
tied_model = model_class(config=tied_config).to(torch_device).eval()
|
||||
|
||||
model_result = model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
tied_model_result = tied_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# check that models has less parameters
|
||||
self.parent.assertLess(
|
||||
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
|
||||
)
|
||||
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
|
||||
|
||||
# check that outputs are equal
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(
|
||||
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
# check that outputs after saving and loading are equal
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tied_model.save_pretrained(tmpdirname)
|
||||
tied_model = model_class.from_pretrained(tmpdirname)
|
||||
tied_model.to(torch_device)
|
||||
tied_model.eval()
|
||||
|
||||
# check that models has less parameters
|
||||
self.parent.assertLess(
|
||||
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
|
||||
)
|
||||
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
|
||||
|
||||
tied_model_result = tied_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# check that outputs are equal
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(
|
||||
model_result[0][0, :, random_slice_idx],
|
||||
tied_model_result[0][0, :, random_slice_idx],
|
||||
atol=1e-4,
|
||||
)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs
|
||||
@@ -299,30 +371,34 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
|
||||
|
||||
def test_t5_model(self):
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_with_lm_head(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
|
||||
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
|
||||
|
||||
def test_t5_decoder_model_past(self):
|
||||
def test_decoder_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
|
||||
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||
|
||||
def test_t5_decoder_model_past_with_attn_mask(self):
|
||||
def test_decoder_model_past_with_attn_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
|
||||
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||
|
||||
def test_t5_generate_with_past_key_value_states(self):
|
||||
def test_generate_with_past_key_value_states(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
|
||||
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
|
||||
|
||||
def test_encoder_decoder_shared_weights(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||
def test_t5_model_fp16_forward(self):
|
||||
def test_model_fp16_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@@ -331,8 +407,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_export_to_onnx(self):
|
||||
import tempfile
|
||||
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = T5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user