improve docstrings and fix new token classification model

This commit is contained in:
thomwolf
2018-11-30 22:55:26 +01:00
parent ed302a73f4
commit d787c6be8c

View File

@@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel):
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block, to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first token of the classifier pretrained on top of the hidden state associated to the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper). input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
@@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel):
sentence classification loss. sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`: if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising Outputs a tuple comprising
- the masked language modeling logits, and - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits. - the next sentence classification logits of shape [batch_size, 2].
Example usage: Example usage:
```python ```python
@@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel):
if masked_lm_labels is not None and next_sentence_label is not None: if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1)) masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss total_loss = masked_lm_loss + next_sentence_loss
return total_loss return total_loss
@@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel):
if `masked_lm_labels` is not `None`: if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss. Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`: if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits. Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
Example usage: Example usage:
```python ```python
@@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
Outputs the total_loss which is the sum of the masked language modeling loss and the next Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss. sentence classification loss.
if `next_sentence_label` is `None`: if `next_sentence_label` is `None`:
Outputs the next sentence classification logits. Outputs the next sentence classification logits of shape [batch_size, 2].
Example usage: Example usage:
```python ```python
@@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if `labels` is not `None`: if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels. Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`: if `labels` is `None`:
Outputs the classification logits. Outputs the classification logits of shape [batch_size, num_labels].
Example usage: Example usage:
```python ```python
@@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss, logits return loss
else: else:
return logits return logits
@@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel):
if `labels` is not `None`: if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels. Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`: if `labels` is `None`:
Outputs the classification logits. Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage: Example usage:
```python ```python
@@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel):
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss, logits return loss
else: else:
return logits return logits
@@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel):
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`: if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens. position tokens of shape [batch_size, sequence_length].
Example usage: Example usage:
```python ```python