don't initialize the output embeddings if we're going to tie them to input embeddings (#28192)
* test that tied output embeddings aren't initialized on load * don't initialize the output embeddings if we're going to tie them to the input embeddings
This commit is contained in:
@@ -3746,6 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
_loaded_keys = loaded_keys
|
||||
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
|
||||
# if we're about to tie the output embeds to the input embeds we don't need to init them
|
||||
if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
|
||||
output_embeddings = model.get_output_embeddings()
|
||||
if output_embeddings is not None:
|
||||
output_embeddings._is_hf_initialized = True
|
||||
else:
|
||||
not_initialized_submodules = dict(model.named_modules())
|
||||
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||
|
||||
@@ -483,6 +483,40 @@ class ModelTesterMixin:
|
||||
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
||||
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_fast_init_tied_embeddings(self):
|
||||
class MyClass(PreTrainedModel):
|
||||
config_class = PretrainedConfig
|
||||
_tied_weights_keys = ["output_embeddings.weight"]
|
||||
|
||||
def __init__(self, config=None):
|
||||
super().__init__(config if config is not None else PretrainedConfig())
|
||||
self.input_embeddings = nn.Embedding(10, 10)
|
||||
self.output_embeddings = nn.Linear(10, 10, bias=False)
|
||||
self.tie_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.output_embeddings
|
||||
|
||||
def set_output_embeddings(self, output_embeddings):
|
||||
self.output_embeddings = output_embeddings
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.input_embeddings
|
||||
|
||||
def set_input_embeddings(self, input_embeddings):
|
||||
self.input_embeddings = input_embeddings
|
||||
|
||||
def _init_weights(self, module):
|
||||
if module is self.output_embeddings:
|
||||
raise ValueError("unnecessarily initialized tied output embedding!")
|
||||
|
||||
model = MyClass()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
# throws if it initializes the tied output_embeddings
|
||||
MyClass.from_pretrained(tmpdirname)
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
|
||||
Reference in New Issue
Block a user