update doc for XLM and XLNet

This commit is contained in:
thomwolf
2019-07-15 11:36:50 +02:00
parent 0201d86015
commit 44c985facd
7 changed files with 459 additions and 561 deletions

View File

@@ -611,11 +611,11 @@ BERT_INPUTS_DOCSTRING = r"""
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask indices selected in ``[0, 1]``:
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask indices selected in ``[0, 1]``:
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@@ -714,7 +714,7 @@ class BertModel(BertPreTrainedModel):
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with two heads on top as done during the pre-training:
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForPreTraining(BertPreTrainedModel):
@@ -791,7 +791,7 @@ class BertForPreTraining(BertPreTrainedModel):
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with a `language modeling` head on top. """,
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
r"""
@@ -856,7 +856,7 @@ class BertForMaskedLM(BertPreTrainedModel):
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with a `next sentence prediction (classification)` head on top. """,
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForNextSentencePrediction(BertPreTrainedModel):
r"""
@@ -913,7 +913,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassification(BertPreTrainedModel):
@@ -981,7 +981,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with a multiple choice classification head on top (a linear layer on top of
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
BERT_START_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
@@ -1016,11 +1016,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask indices selected in ``[0, 1]``:
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask indices selected in ``[0, 1]``:
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
@@ -1087,7 +1087,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with a token classification head on top (a linear layer on top of
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForTokenClassification(BertPreTrainedModel):
@@ -1154,17 +1154,17 @@ class BertForTokenClassification(BertPreTrainedModel):
return outputs # (loss), scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer BERT model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Position (index) of the start of the labelled span for computing the token classification loss.
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Position (index) of the end of the labelled span for computing the token classification loss.
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.