This commit is contained in:
thomwolf
2019-01-07 13:37:55 +01:00
7 changed files with 703 additions and 46 deletions

View File

@@ -439,8 +439,8 @@ class PreTrainedModel(nn.Module):
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
-module.bias.data.normal_(mean=0.0, std=self.config.initializer_range)
-module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+module.bias.data.zero_()
+module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@@ -456,7 +456,9 @@ class PreTrainedModel(nn.Module):
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
-. `bert-base-multilingual`
+. `bert-large-cased`
+. `bert-base-multilingual-uncased`
+. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
@@ -728,7 +730,7 @@ class BertForMaskedLM(PreTrainedModel):
is only computed for the labels set in [0, ..., vocab_size]
Outputs:
-if `masked_lm_labels` is `None`:
+if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
@@ -1035,15 +1037,7 @@ class BertForQuestionAnswering(PreTrainedModel):
the sequence output that computes start_logits and end_logits
Params:
-`config`: either
-- a BertConfig class instance with the configuration to build a new model, or
-- a str with the name of a pre-trained model to load selected in the list of:
-. `bert-base-uncased`
-. `bert-large-uncased`
-. `bert-base-cased`
-. `bert-base-multilingual`
-. `bert-base-chinese`
-The pre-trained model will be downloaded and cached if needed.
+`config`: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]