improve docstrings and fix new token classification model
This commit is contained in:
@@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel):
|
||||
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
|
||||
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
||||
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
||||
to the last attention block,
|
||||
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
||||
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
||||
classifier pretrained on top of the hidden state associated to the first character of the
|
||||
input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
|
||||
@@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel):
|
||||
sentence classification loss.
|
||||
if `masked_lm_labels` or `next_sentence_label` is `None`:
|
||||
Outputs a tuple comprising
|
||||
- the masked language modeling logits, and
|
||||
- the next sentence classification logits.
|
||||
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
||||
- the next sentence classification logits of shape [batch_size, 2].
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel):
|
||||
|
||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
return total_loss
|
||||
@@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel):
|
||||
if `masked_lm_labels` is `None`:
|
||||
Outputs the masked language modeling loss.
|
||||
if `masked_lm_labels` is `None`:
|
||||
Outputs the masked language modeling logits.
|
||||
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
|
||||
Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
||||
sentence classification loss.
|
||||
if `next_sentence_label` is `None`:
|
||||
Outputs the next sentence classification logits.
|
||||
Outputs the next sentence classification logits of shape [batch_size, 2].
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
||||
if `labels` is not `None`:
|
||||
Outputs the CrossEntropy classification loss of the output with the labels.
|
||||
if `labels` is `None`:
|
||||
Outputs the classification logits.
|
||||
Outputs the classification logits of shape [batch_size, num_labels].
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
return loss, logits
|
||||
return loss
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel):
|
||||
if `labels` is not `None`:
|
||||
Outputs the CrossEntropy classification loss of the output with the labels.
|
||||
if `labels` is `None`:
|
||||
Outputs the classification logits.
|
||||
Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel):
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
return loss, logits
|
||||
return loss
|
||||
else:
|
||||
return logits
|
||||
|
||||
@@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel):
|
||||
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
|
||||
if `start_positions` or `end_positions` is `None`:
|
||||
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
|
||||
position tokens.
|
||||
position tokens of shape [batch_size, sequence_length].
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
|
||||
Reference in New Issue
Block a user