improve docstrings and fix new token classification model
This commit is contained in:
@@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel):
|
|||||||
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
|
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
|
||||||
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
||||||
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
||||||
to the last attention block,
|
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
||||||
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
||||||
classifier pretrained on top of the hidden state associated to the first token of the
|
classifier pretrained on top of the hidden state associated to the first token of the
|
||||||
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
|
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
|
||||||
@@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel):
|
|||||||
sentence classification loss.
|
sentence classification loss.
|
||||||
if `masked_lm_labels` or `next_sentence_label` is `None`:
|
if `masked_lm_labels` or `next_sentence_label` is `None`:
|
||||||
Outputs a tuple comprising
|
Outputs a tuple comprising
|
||||||
- the masked language modeling logits, and
|
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
||||||
- the next sentence classification logits.
|
- the next sentence classification logits of shape [batch_size, 2].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
@@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel):
|
|||||||
|
|
||||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||||
total_loss = masked_lm_loss + next_sentence_loss
|
total_loss = masked_lm_loss + next_sentence_loss
|
||||||
return total_loss
|
return total_loss
|
||||||
@@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel):
|
|||||||
if `masked_lm_labels` is not `None`:
|
if `masked_lm_labels` is not `None`:
|
||||||
Outputs the masked language modeling loss.
|
Outputs the masked language modeling loss.
|
||||||
if `masked_lm_labels` is `None`:
|
if `masked_lm_labels` is `None`:
|
||||||
Outputs the masked language modeling logits.
|
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
@@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
|
|||||||
Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
||||||
sentence classification loss.
|
sentence classification loss.
|
||||||
if `next_sentence_label` is `None`:
|
if `next_sentence_label` is `None`:
|
||||||
Outputs the next sentence classification logits.
|
Outputs the next sentence classification logits of shape [batch_size, 2].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
@@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
|||||||
if `labels` is not `None`:
|
if `labels` is not `None`:
|
||||||
Outputs the CrossEntropy classification loss of the output with the labels.
|
Outputs the CrossEntropy classification loss of the output with the labels.
|
||||||
if `labels` is `None`:
|
if `labels` is `None`:
|
||||||
Outputs the classification logits.
|
Outputs the classification logits of shape [batch_size, num_labels].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
@@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss, logits
|
return loss
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel):
|
|||||||
if `labels` is not `None`:
|
if `labels` is not `None`:
|
||||||
Outputs the CrossEntropy classification loss of the output with the labels.
|
Outputs the CrossEntropy classification loss of the output with the labels.
|
||||||
if `labels` is `None`:
|
if `labels` is `None`:
|
||||||
Outputs the classification logits.
|
Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
@@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss, logits
|
return loss
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel):
|
|||||||
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
|
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
|
||||||
if `start_positions` or `end_positions` is `None`:
|
if `start_positions` or `end_positions` is `None`:
|
||||||
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
|
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
|
||||||
position tokens.
|
position tokens of shape [batch_size, sequence_length].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
|
|||||||
Reference in New Issue
Block a user