[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)
* start adding tie encoder to decoder functionality * finish model tying * make style * Apply suggestions from code review * fix t5 list including cross attention * apply sams suggestions * Update src/transformers/modeling_encoder_decoder.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add max depth break point Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ab42d74850
commit
fe0b85e77a
@@ -87,7 +87,7 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_configs(
|
||||
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
|
||||
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
|
||||
) -> PretrainedConfig:
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
|
||||
@@ -99,7 +99,7 @@ class EncoderDecoderConfig(PretrainedConfig):
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
|
||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
|
||||
@@ -58,6 +58,8 @@ class PretrainedConfig(object):
|
||||
Whether the model is used as decoder or not (in which case it's used as an encoder).
|
||||
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
|
||||
tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`)
|
||||
Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have the exact same parameter names.
|
||||
prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
|
||||
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
|
||||
of heads to prune in said layer.
|
||||
@@ -153,6 +155,7 @@ class PretrainedConfig(object):
|
||||
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
||||
self.is_decoder = kwargs.pop("is_decoder", False)
|
||||
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
|
||||
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
|
||||
|
||||
# Parameters for sequence generation
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
|
||||
@@ -71,9 +71,17 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
self.encoder.get_output_embeddings() is None
|
||||
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
|
||||
|
||||
# tie encoder, decoder weights if config set accordingly
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
# for now no weights tying in encoder-decoder
|
||||
pass
|
||||
# tie encoder & decoder if needed
|
||||
if self.config.tie_encoder_decoder:
|
||||
# tie encoder and decoder base model
|
||||
decoder_base_model_prefix = self.decoder.base_model_prefix
|
||||
self._tie_encoder_decoder_weights(
|
||||
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
|
||||
)
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
@@ -122,7 +130,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments.
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``).
|
||||
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter
|
||||
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter
|
||||
Behave differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||
|
||||
Examples::
|
||||
|
||||
@@ -144,6 +156,12 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# remove encoder, decoder kwargs from kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
del kwargs["encoder_" + key]
|
||||
for key in kwargs_decoder.keys():
|
||||
del kwargs["decoder_" + key]
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
@@ -184,7 +202,9 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
return cls(encoder=encoder, decoder=decoder)
|
||||
# instantiate config with corresponding kwargs
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||
return cls(encoder=encoder, decoder=decoder, config=config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -887,10 +887,12 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.use_cache = False
|
||||
encoder_config.is_encoder_decoder = False
|
||||
self.encoder = T5Stack(encoder_config, self.shared)
|
||||
|
||||
decoder_config = copy.deepcopy(config)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.is_encoder_decoder = False
|
||||
self.decoder = T5Stack(decoder_config, self.shared)
|
||||
|
||||
self.init_weights()
|
||||
@@ -1040,10 +1042,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.use_cache = False
|
||||
encoder_config.is_encoder_decoder = False
|
||||
self.encoder = T5Stack(encoder_config, self.shared)
|
||||
|
||||
decoder_config = copy.deepcopy(config)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.is_encoder_decoder = False
|
||||
self.decoder = T5Stack(decoder_config, self.shared)
|
||||
|
||||
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
|
||||
@@ -416,6 +416,77 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if output_embeddings is not None:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
|
||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
||||
|
||||
@staticmethod
|
||||
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
||||
uninitialized_encoder_weights: List[str] = []
|
||||
assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
|
||||
|
||||
def tie_encoder_to_decoder_recursively(
|
||||
decoder_pointer: nn.Module,
|
||||
encoder_pointer: nn.Module,
|
||||
module_name: str,
|
||||
uninitialized_encoder_weights: List[str],
|
||||
depth=0,
|
||||
):
|
||||
assert isinstance(decoder_pointer, nn.Module) and isinstance(
|
||||
encoder_pointer, nn.Module
|
||||
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
|
||||
if hasattr(decoder_pointer, "weight"):
|
||||
assert hasattr(encoder_pointer, "weight")
|
||||
encoder_pointer.weight = decoder_pointer.weight
|
||||
if hasattr(decoder_pointer, "bias"):
|
||||
assert hasattr(encoder_pointer, "bias")
|
||||
encoder_pointer.bias = decoder_pointer.bias
|
||||
return
|
||||
|
||||
encoder_modules = encoder_pointer._modules
|
||||
decoder_modules = decoder_pointer._modules
|
||||
if len(decoder_modules) > 0:
|
||||
assert (
|
||||
len(encoder_modules) > 0
|
||||
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||
|
||||
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
|
||||
encoder_layer_pos = 0
|
||||
for name, module in decoder_modules.items():
|
||||
if name.isdigit():
|
||||
encoder_name = str(int(name) + encoder_layer_pos)
|
||||
decoder_name = name
|
||||
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
|
||||
# this can happen if the name corresponds to the position in a list module list of layers
|
||||
# in this case the decoder has added a cross-attention that the encoder does not have
|
||||
# thus skip this step and substract one layer pos from encoder
|
||||
encoder_layer_pos -= 1
|
||||
continue
|
||||
elif name not in encoder_modules:
|
||||
continue
|
||||
elif depth > 500:
|
||||
raise ValueError(
|
||||
"Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
|
||||
)
|
||||
else:
|
||||
decoder_name = encoder_name = name
|
||||
tie_encoder_to_decoder_recursively(
|
||||
decoder_modules[decoder_name],
|
||||
encoder_modules[encoder_name],
|
||||
module_name + "/" + name,
|
||||
uninitialized_encoder_weights,
|
||||
depth=depth + 1,
|
||||
)
|
||||
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
||||
|
||||
uninitialized_encoder_weights += list(all_encoder_weights)
|
||||
|
||||
# tie weights recursively
|
||||
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
|
||||
if len(uninitialized_encoder_weights) > 0:
|
||||
logger.warning(
|
||||
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
|
||||
)
|
||||
|
||||
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
||||
""" Tie or clone module weights depending of whether we are using TorchScript or not
|
||||
"""
|
||||
@@ -894,7 +965,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
model.tie_weights() # make sure token embedding weights are still tied if needed
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user