From fe0b85e77a6af041471657069bbb9c21a880cd5c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 19 Aug 2020 14:23:45 +0200 Subject: [PATCH] [EncoderDecoder] Add functionality to tie encoder decoder weights (#6538) * start adding tie encoder to decoder functionality * finish model tying * make style * Apply suggestions from code review * fix t5 list including cross attention * apply sams suggestions * Update src/transformers/modeling_encoder_decoder.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add max depth break point Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../configuration_encoder_decoder.py | 4 +- src/transformers/configuration_utils.py | 3 + src/transformers/modeling_encoder_decoder.py | 28 ++++- src/transformers/modeling_t5.py | 4 + src/transformers/modeling_utils.py | 74 +++++++++++- tests/test_modeling_encoder_decoder.py | 89 ++++++++++++++ tests/test_modeling_t5.py | 112 +++++++++++++++--- 7 files changed, 288 insertions(+), 26 deletions(-) diff --git a/src/transformers/configuration_encoder_decoder.py b/src/transformers/configuration_encoder_decoder.py index 95cabaa82e..65c4021d3b 100644 --- a/src/transformers/configuration_encoder_decoder.py +++ b/src/transformers/configuration_encoder_decoder.py @@ -87,7 +87,7 @@ class EncoderDecoderConfig(PretrainedConfig): @classmethod def from_encoder_decoder_configs( - cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig + cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs ) -> PretrainedConfig: r""" Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration. @@ -99,7 +99,7 @@ class EncoderDecoderConfig(PretrainedConfig): decoder_config.is_decoder = True decoder_config.add_cross_attention = True - return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict()) + return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) def to_dict(self): """ diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2ccc0492ac..26a8f3ca35 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -58,6 +58,8 @@ class PretrainedConfig(object): Whether the model is used as decoder or not (in which case it's used as an encoder). add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``. + tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`) + Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have the exact same parameter names. prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`): Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of heads to prune in said layer. @@ -153,6 +155,7 @@ class PretrainedConfig(object): self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) self.is_decoder = kwargs.pop("is_decoder", False) self.add_cross_attention = kwargs.pop("add_cross_attention", False) + self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False) # Parameters for sequence generation self.max_length = kwargs.pop("max_length", 20) diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 664b4181f5..b351d32142 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -71,9 +71,17 @@ class EncoderDecoderModel(PreTrainedModel): self.encoder.get_output_embeddings() is None ), "The encoder {} should not have a LM Head. Please use a model without LM Head" + # tie encoder, decoder weights if config set accordingly + self.tie_weights() + def tie_weights(self): - # for now no weights tying in encoder-decoder - pass + # tie encoder & decoder if needed + if self.config.tie_encoder_decoder: + # tie encoder and decoder base model + decoder_base_model_prefix = self.decoder.base_model_prefix + self._tie_encoder_decoder_weights( + self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix + ) def get_encoder(self): return self.encoder @@ -122,7 +130,11 @@ class EncoderDecoderModel(PreTrainedModel): All remaning positional arguments will be passed to the underlying model's ``__init__`` method kwargs: (`optional`) Remaining dictionary of keyword arguments. - Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). + - To update the encoder configuration, use the prefix `encoder_` for each configuration parameter + - To update the decoder configuration, use the prefix `decoder_` for each configuration parameter + - To update the parent model configuration, do not use a prefix for each configuration parameter + Behave differently depending on whether a :obj:`config` is provided or automatically loaded. Examples:: @@ -144,6 +156,12 @@ class EncoderDecoderModel(PreTrainedModel): argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder.keys(): + del kwargs["encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + # Load and initialize the encoder and decoder # The distinction between encoder and decoder at the model level is made # by the value of the flag `is_decoder` that we need to set correctly. @@ -184,7 +202,9 @@ class EncoderDecoderModel(PreTrainedModel): decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - return cls(encoder=encoder, decoder=decoder) + # instantiate config with corresponding kwargs + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + return cls(encoder=encoder, decoder=decoder, config=config) def forward( self, diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 3b718f242f..0b2d55f575 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -887,10 +887,12 @@ class T5Model(T5PreTrainedModel): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False self.decoder = T5Stack(decoder_config, self.shared) self.init_weights() @@ -1040,10 +1042,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False self.decoder = T5Stack(decoder_config, self.shared) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b5aad402b2..62300c1a61 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -416,6 +416,77 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): if output_embeddings is not None: self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: + self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) + + @staticmethod + def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str): + uninitialized_encoder_weights: List[str] = [] + assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal." + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" + if hasattr(decoder_pointer, "weight"): + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and substract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) + if len(uninitialized_encoder_weights) > 0: + logger.warning( + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" + ) + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): """ Tie or clone module weights depending of whether we are using TorchScript or not """ @@ -894,7 +965,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): model.__class__.__name__, "\n\t".join(error_msgs) ) ) - model.tie_weights() # make sure token embedding weights are still tied if needed + # make sure token embedding weights are still tied if needed + model.tie_weights() # Set model in evaluation mode to deactivate DropOut modules by default model.eval() diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index e56a04369b..da5112e8b8 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -268,6 +268,88 @@ class EncoderDecoderMixin: ) self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,)) + def create_and_check_encoder_decoder_shared_weights( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + labels, + **kwargs + ): + torch.manual_seed(0) + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + model.to(torch_device) + model.eval() + # load state dict copies weights but does not tie them + decoder_state_dict = model.decoder._modules[model.decoder.base_model_prefix].state_dict() + model.encoder.load_state_dict(decoder_state_dict, strict=False) + + torch.manual_seed(0) + tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(config, decoder_config) + config = EncoderDecoderConfig.from_encoder_decoder_configs( + tied_encoder_model.config, tied_decoder_model.config, tie_encoder_decoder=True + ) + tied_model = EncoderDecoderModel(encoder=tied_encoder_model, decoder=tied_decoder_model, config=config) + tied_model.to(torch_device) + tied_model.eval() + + model_result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that models has less parameters + self.assertLess(sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + # check that outputs are equal + self.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + tied_model.save_pretrained(tmpdirname) + tied_model = EncoderDecoderModel.from_pretrained(tmpdirname) + tied_model.to(torch_device) + tied_model.eval() + + # check that models has less parameters + self.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that outputs are equal + self.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + def test_encoder_decoder_model(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model(**input_ids_dict) @@ -296,6 +378,10 @@ class EncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_generate(**input_ids_dict) + def test_encoder_decoder_model_shared_weights(self): + input_ids_dict = self.prepare_config_and_inputs() + self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict) + @slow def test_real_model_save_load_from_pretrained(self): model_2 = self.get_pretrained_model() @@ -480,3 +566,6 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): def get_pretrained_model(self): return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2") + + def test_encoder_decoder_model_shared_weights(self): + pass diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index d79cc98164..53365005fc 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -14,6 +14,8 @@ # limitations under the License. +import copy +import tempfile import unittest from transformers import is_torch_available @@ -130,7 +132,7 @@ class T5ModelTester: # all items after square self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) - def create_and_check_t5_model( + def create_and_check_model( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): model = T5Model(config=config) @@ -156,7 +158,7 @@ class T5ModelTester: # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple self.parent.assertEqual(len(decoder_past[1][0]), 4) - def create_and_check_t5_with_lm_head( + def create_and_check_with_lm_head( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): model = T5ForConditionalGeneration(config=config).to(torch_device).eval() @@ -170,7 +172,7 @@ class T5ModelTester: self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) self.parent.assertEqual(outputs["loss"].size(), ()) - def create_and_check_t5_decoder_model_past( + def create_and_check_decoder_model_past( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): model = T5Model(config=config).get_decoder().to(torch_device).eval() @@ -201,7 +203,7 @@ class T5ModelTester: # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - def create_and_check_t5_decoder_model_attention_mask_past( + def create_and_check_decoder_model_attention_mask_past( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): model = T5Model(config=config).get_decoder() @@ -245,7 +247,7 @@ class T5ModelTester: # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - def create_t5_and_check_t5_generate_with_past_key_value_states( + def create_and_check_generate_with_past_key_value_states( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): model = T5ForConditionalGeneration(config=config).to(torch_device).eval() @@ -257,13 +259,83 @@ class T5ModelTester: output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) - def create_and_check_t5_model_fp16_forward( + def create_and_check_model_fp16_forward( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): model = T5Model(config=config).to(torch_device).half().eval() output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] self.parent.assertFalse(torch.isnan(output).any().item()) + def create_and_check_encoder_decoder_shared_weights( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, + ): + for model_class in [T5Model, T5ForConditionalGeneration]: + torch.manual_seed(0) + model = model_class(config=config).to(torch_device).eval() + # load state dict copies weights but does not tie them + model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) + + torch.manual_seed(0) + tied_config = copy.deepcopy(config) + tied_config.tie_encoder_decoder = True + tied_model = model_class(config=tied_config).to(torch_device).eval() + + model_result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + tied_model.save_pretrained(tmpdirname) + tied_model = model_class.from_pretrained(tmpdirname) + tied_model.to(torch_device) + tied_model.eval() + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], + tied_model_result[0][0, :, random_slice_idx], + atol=1e-4, + ) + ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs @@ -299,30 +371,34 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) - def test_t5_model(self): + def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_t5_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) def test_with_lm_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs) + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) - def test_t5_decoder_model_past(self): + def test_decoder_model_past(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs) + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) - def test_t5_decoder_model_past_with_attn_mask(self): + def test_decoder_model_past_with_attn_mask(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs) + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) - def test_t5_generate_with_past_key_value_states(self): + def test_generate_with_past_key_value_states(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs) + self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs) + + def test_encoder_decoder_shared_weights(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_t5_model_fp16_forward(self): + def test_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs) + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) @slow def test_model_from_pretrained(self): @@ -331,8 +407,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) def test_export_to_onnx(self): - import tempfile - config_and_inputs = self.model_tester.prepare_config_and_inputs() model = T5Model(config_and_inputs[0]).to(torch_device) with tempfile.TemporaryDirectory() as tmpdirname: