Merge branch 'master' into develop

This commit is contained in:
Thomas Wolf
2018-11-07 23:35:42 +01:00
committed by GitHub
2 changed files with 328 additions and 98 deletions

View File

@@ -388,10 +388,10 @@ class BertForSequenceClassification(nn.Module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(config.initializer_range)
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(config.initializer_range)
module.gamma.data.normal_(config.initializer_range)
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
@@ -438,10 +438,10 @@ class BertForQuestionAnswering(nn.Module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(config.initializer_range)
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(config.initializer_range)
module.gamma.data.normal_(config.initializer_range)
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
@@ -459,7 +459,7 @@ class BertForQuestionAnswering(nn.Module):
start_positions = start_positions.squeeze(-1)
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) + 1
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)