From d787c6be8c5fc88190f5723745f5e05d15f6b30c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Nov 2018 22:55:26 +0100 Subject: [PATCH] improve docstrings and fix new token classification model --- pytorch_pretrained_bert/modeling.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index e8ad26a1c6..13666c86df 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel): of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block, + to the last attention block of shape [batch_size, sequence_length, hidden_size], `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLS`) to train on the Next-Sentence task (see BERT's paper). @@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel): sentence classification loss. if `masked_lm_labels` or `next_sentence_label` is `None`: Outputs a tuple comprising - - the masked language modeling logits, and - - the next sentence classification logits. + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. 
Example usage: ```python @@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel): if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1)) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) total_loss = masked_lm_loss + next_sentence_loss return total_loss @@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel): if `masked_lm_labels` is not `None`: Outputs the masked language modeling loss. if `masked_lm_labels` is `None`: - Outputs the masked language modeling logits. + Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. Example usage: ```python @@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel): Outputs the total_loss which is the sum of the masked language modeling loss and the next sentence classification loss. if `next_sentence_label` is `None`: - Outputs the next sentence classification logits. + Outputs the next sentence classification logits of shape [batch_size, 2]. Example usage: ```python @@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel): if `labels` is not `None`: Outputs the CrossEntropy classification loss of the output with the labels. if `labels` is `None`: - Outputs the classification logits. + Outputs the classification logits of shape [batch_size, num_labels]. 
Example usage: ```python @@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel): if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss, logits + return loss else: return logits @@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel): if `labels` is not `None`: Outputs the CrossEntropy classification loss of the output with the labels. if `labels` is `None`: - Outputs the classification logits. + Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. Example usage: ```python @@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel): if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss, logits + return loss else: return logits @@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel): Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. if `start_positions` or `end_positions` is `None`: Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end - position tokens. + position tokens of shape [batch_size, sequence_length]. Example usage: ```python