don't initialize the output embeddings if we're going to tie them to input embeddings (#28192)
* test that tied output embeddings aren't initialized on load * don't initialize the output embeddings if we're going to tie them to the input embeddings
This commit is contained in:
@@ -3746,6 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
_loaded_keys = loaded_keys
|
_loaded_keys = loaded_keys
|
||||||
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
|
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
|
||||||
|
# if we're about to tie the output embeds to the input embeds we don't need to init them
|
||||||
|
if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
|
||||||
|
output_embeddings = model.get_output_embeddings()
|
||||||
|
if output_embeddings is not None:
|
||||||
|
output_embeddings._is_hf_initialized = True
|
||||||
else:
|
else:
|
||||||
not_initialized_submodules = dict(model.named_modules())
|
not_initialized_submodules = dict(model.named_modules())
|
||||||
# This will only initialize submodules that are not marked as initialized by the line above.
|
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||||
|
|||||||
@@ -483,6 +483,40 @@ class ModelTesterMixin:
|
|||||||
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
||||||
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_fast_init_tied_embeddings(self):
|
||||||
|
class MyClass(PreTrainedModel):
|
||||||
|
config_class = PretrainedConfig
|
||||||
|
_tied_weights_keys = ["output_embeddings.weight"]
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
super().__init__(config if config is not None else PretrainedConfig())
|
||||||
|
self.input_embeddings = nn.Embedding(10, 10)
|
||||||
|
self.output_embeddings = nn.Linear(10, 10, bias=False)
|
||||||
|
self.tie_weights()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.output_embeddings
|
||||||
|
|
||||||
|
def set_output_embeddings(self, output_embeddings):
|
||||||
|
self.output_embeddings = output_embeddings
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.input_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, input_embeddings):
|
||||||
|
self.input_embeddings = input_embeddings
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if module is self.output_embeddings:
|
||||||
|
raise ValueError("unnecessarily initialized tied output embedding!")
|
||||||
|
|
||||||
|
model = MyClass()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
# throws if it initializes the tied output_embeddings
|
||||||
|
MyClass.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
def test_save_load_fast_init_to_base(self):
|
def test_save_load_fast_init_to_base(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if config.__class__ not in MODEL_MAPPING:
|
if config.__class__ not in MODEL_MAPPING:
|
||||||
|
|||||||
Reference in New Issue
Block a user