From 2d07f945adfd41389b5dd45d85af37d404a09599 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Thu, 6 Jun 2019 17:10:24 +0200 Subject: [PATCH] fix error with torch.no_grad and loss computation --- hubconfs/bert_hubconf.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/hubconfs/bert_hubconf.py b/hubconfs/bert_hubconf.py index c7bcfbffb6..a547a33c22 100644 --- a/hubconfs/bert_hubconf.py +++ b/hubconfs/bert_hubconf.py @@ -238,8 +238,7 @@ def bertForSequenceClassification(*args, **kwargs): seq_classif_logits = model(tokens_tensor, segments_tensors) # Or get the sequence classification loss >>> labels = torch.tensor([1]) - >>> with torch.no_grad(): - seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels) + >>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels) """ model = BertForSequenceClassification.from_pretrained(*args, **kwargs) return model @@ -273,8 +272,7 @@ def bertForMultipleChoice(*args, **kwargs): multiple_choice_logits = model(tokens_tensor, segments_tensors) # Or get the multiple choice loss >>> labels = torch.tensor([1]) - >>> with torch.no_grad(): - multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels) + >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels) """ model = BertForMultipleChoice.from_pretrained(*args, **kwargs) return model @@ -306,8 +304,7 @@ def bertForQuestionAnswering(*args, **kwargs): start_logits, end_logits = model(tokens_tensor, segments_tensors) # Or get the total loss which is the sum of the CrossEntropy loss for the start and end token positions >>> start_positions, end_positions = torch.tensor([12]), torch.tensor([14]) - >>> with torch.no_grad(): - multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions) + >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions) """ model = BertForQuestionAnswering.from_pretrained(*args, **kwargs) return model @@ -344,8 +341,7 @@ def bertForTokenClassification(*args, **kwargs): classif_logits = model(tokens_tensor, segments_tensors) # Or get the token classification loss >>> labels = torch.tensor([[0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]]) - >>> with torch.no_grad(): - classif_loss = model(tokens_tensor, segments_tensors, labels=labels) + >>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels) """ model = BertForTokenClassification.from_pretrained(*args, **kwargs) return model