From 96592b544bb460085bb5e2522070254849e82350 Mon Sep 17 00:00:00 2001 From: Victor SANH Date: Thu, 30 May 2019 15:53:13 -0400 Subject: [PATCH] default in __init__s for classification BERT models (#650) --- pytorch_pretrained_bert/modeling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index d1c4c07c98..ac6c337405 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -980,7 +980,7 @@ class BertForSequenceClassification(BertPreTrainedModel): logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_labels): + def __init__(self, config, num_labels=2): super(BertForSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) @@ -1045,7 +1045,7 @@ class BertForMultipleChoice(BertPreTrainedModel): logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_choices): + def __init__(self, config, num_choices=2): super(BertForMultipleChoice, self).__init__(config) self.num_choices = num_choices self.bert = BertModel(config) @@ -1115,7 +1115,7 @@ class BertForTokenClassification(BertPreTrainedModel): logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_labels): + def __init__(self, config, num_labels=2): super(BertForTokenClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config)