don't initialize the output embeddings if we're going to tie them to input embeddings (#28192)
* test that tied output embeddings aren't initialized on load * don't initialize the output embeddings if we're going to tie them to the input embeddings
This commit is contained in:
@@ -483,6 +483,40 @@ class ModelTesterMixin:
|
||||
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
||||
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_fast_init_tied_embeddings(self):
|
||||
class MyClass(PreTrainedModel):
|
||||
config_class = PretrainedConfig
|
||||
_tied_weights_keys = ["output_embeddings.weight"]
|
||||
|
||||
def __init__(self, config=None):
|
||||
super().__init__(config if config is not None else PretrainedConfig())
|
||||
self.input_embeddings = nn.Embedding(10, 10)
|
||||
self.output_embeddings = nn.Linear(10, 10, bias=False)
|
||||
self.tie_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.output_embeddings
|
||||
|
||||
def set_output_embeddings(self, output_embeddings):
|
||||
self.output_embeddings = output_embeddings
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.input_embeddings
|
||||
|
||||
def set_input_embeddings(self, input_embeddings):
|
||||
self.input_embeddings = input_embeddings
|
||||
|
||||
def _init_weights(self, module):
|
||||
if module is self.output_embeddings:
|
||||
raise ValueError("unnecessarily initialized tied output embedding!")
|
||||
|
||||
model = MyClass()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
# throws if it initializes the tied output_embeddings
|
||||
MyClass.from_pretrained(tmpdirname)
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
|
||||
Reference in New Issue
Block a user