special edition script

This commit is contained in:
thomwolf
2018-11-03 19:06:15 +01:00
parent 25f73add07
commit 04287a4d68
3 changed files with 108 additions and 4 deletions

View File

@@ -482,9 +482,14 @@ class BertForQuestionAnswering(nn.Module):
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Embedding)):
print("Initializing {}".format(m))
# Slight difference here with the TF version which uses truncated_normal
# Slight difference here with the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
m.weight.data.normal_(config.initializer_range)
elif isinstance(m, BERTLayerNorm):
m.beta.data.normal_(config.initializer_range)
m.gamme.data.normal_(config.initializer_range)
if isinstance(m, nn.Linear):
m.bias.data.zero_()
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):