default in __init__s for classification BERT models (#650)
This commit is contained in:
@@ -980,7 +980,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels):
|
||||
def __init__(self, config, num_labels=2):
|
||||
super(BertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
@@ -1045,7 +1045,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_choices):
|
||||
def __init__(self, config, num_choices=2):
|
||||
super(BertForMultipleChoice, self).__init__(config)
|
||||
self.num_choices = num_choices
|
||||
self.bert = BertModel(config)
|
||||
@@ -1115,7 +1115,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels):
|
||||
def __init__(self, config, num_labels=2):
|
||||
super(BertForTokenClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
|
||||
Reference in New Issue
Block a user