Fix embeddings for PyTorch 1.8 (#10549)

* Fix embeddings for PyTorch 1.8

* Try with PyTorch 1.8.0

* Fix embeddings init

* Fix copies

* Typo

* More typos
This commit is contained in:
Sylvain Gugger
2021-03-05 16:18:48 -05:00
committed by GitHub
parent 3e056c1003
commit 7da995c00c
27 changed files with 171 additions and 78 deletions

View File

@@ -704,15 +704,19 @@ class BertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass