Fix #29807: sinusoidal positional encodings overwritten by post_init() (#29813)
* Check for requires_grad when initializing weights
* Add unit test
* Move sinusoidal positional encoding generation after post_init()
* Add modules to skip init list
* Move create_sinusoidal_embeddings to _init_weights
This commit is contained in:
committed by
GitHub
parent
cefb819f7a
commit
a81cf9ee90
@@ -106,10 +106,6 @@ class Embeddings(nn.Module):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
|
||||
if config.sinusoidal_pos_embds:
|
||||
create_sinusoidal_embeddings(
|
||||
n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
|
||||
)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
@@ -634,6 +630,10 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
|
||||
create_sinusoidal_embeddings(
|
||||
self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
|
||||
)
|
||||
|
||||
|
||||
DISTILBERT_START_DOCSTRING = r"""
|
||||
|
||||
@@ -37,6 +37,7 @@ if is_torch_available():
|
||||
DistilBertForTokenClassification,
|
||||
DistilBertModel,
|
||||
)
|
||||
from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings
|
||||
|
||||
|
||||
class DistilBertModelTester(object):
|
||||
@@ -238,6 +239,15 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
|
||||
|
||||
def test_distilbert_model_with_sinusoidal_encodings(self):
    """Check that a freshly built model keeps its sinusoidal position table.

    Builds a ``DistilBertModel`` with ``sinusoidal_pos_embds=True`` and
    compares its position-embedding weight against a reference table
    filled by ``_create_sinusoidal_embeddings`` for the same size.
    """
    config = DistilBertConfig(sinusoidal_pos_embds=True)
    model = DistilBertModel(config=config)

    # Reference table: same shape as the position embeddings, filled in place.
    n_positions, hidden_dim = config.max_position_embeddings, config.dim
    expected_weights = torch.empty((n_positions, hidden_dim), dtype=torch.float32)
    _create_sinusoidal_embeddings(n_positions, hidden_dim, expected_weights)

    actual_weights = model.embeddings.position_embeddings.weight
    weights_match = torch.equal(actual_weights, expected_weights)
    self.model_tester.parent.assertTrue(weights_match)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user