From ee0308f79ded65dac82c53dfb03e9ff7f06aeee4 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Thu, 6 Jun 2019 17:30:49 +0200 Subject: [PATCH] fix typo --- hubconfs/bert_hubconf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hubconfs/bert_hubconf.py b/hubconfs/bert_hubconf.py index a547a33c22..0595bdeccb 100644 --- a/hubconfs/bert_hubconf.py +++ b/hubconfs/bert_hubconf.py @@ -238,7 +238,7 @@ def bertForSequenceClassification(*args, **kwargs): seq_classif_logits = model(tokens_tensor, segments_tensors) # Or get the sequence classification loss >>> labels = torch.tensor([1]) - >>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels) + >>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss """ model = BertForSequenceClassification.from_pretrained(*args, **kwargs) return model @@ -272,7 +272,7 @@ def bertForMultipleChoice(*args, **kwargs): multiple_choice_logits = model(tokens_tensor, segments_tensors) # Or get the multiple choice loss >>> labels = torch.tensor([1]) - >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels) + >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss """ model = BertForMultipleChoice.from_pretrained(*args, **kwargs) return model @@ -304,6 +304,7 @@ def bertForQuestionAnswering(*args, **kwargs): start_logits, end_logits = model(tokens_tensor, segments_tensors) # Or get the total loss which is the sum of the CrossEntropy loss for the start and end token positions >>> start_positions, end_positions = torch.tensor([12]), torch.tensor([14]) + # set model.train() before if training this loss >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions) """ model = BertForQuestionAnswering.from_pretrained(*args, **kwargs) @@ -341,7 +342,7 @@ def bertForTokenClassification(*args, **kwargs): classif_logits = model(tokens_tensor, segments_tensors) # Or get the token classification loss >>> labels = torch.tensor([[0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]]) - >>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels) + >>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss """ model = BertForTokenClassification.from_pretrained(*args, **kwargs) return model