default in __init__s for classification BERT models (#650)
This commit is contained in:
@@ -980,7 +980,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, num_labels):
|
def __init__(self, config, num_labels=2):
|
||||||
super(BertForSequenceClassification, self).__init__(config)
|
super(BertForSequenceClassification, self).__init__(config)
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
@@ -1045,7 +1045,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, num_choices):
|
def __init__(self, config, num_choices=2):
|
||||||
super(BertForMultipleChoice, self).__init__(config)
|
super(BertForMultipleChoice, self).__init__(config)
|
||||||
self.num_choices = num_choices
|
self.num_choices = num_choices
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
@@ -1115,7 +1115,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, num_labels):
|
def __init__(self, config, num_labels=2):
|
||||||
super(BertForTokenClassification, self).__init__(config)
|
super(BertForTokenClassification, self).__init__(config)
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user