default in __init__s for classification BERT models (#650)

This commit is contained in:
Victor SANH
2019-05-30 15:53:13 -04:00
committed by GitHub
parent 4cda86b08f
commit 96592b544b

View File

@@ -980,7 +980,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_labels): def __init__(self, config, num_labels=2):
super(BertForSequenceClassification, self).__init__(config) super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels self.num_labels = num_labels
self.bert = BertModel(config) self.bert = BertModel(config)
@@ -1045,7 +1045,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_choices): def __init__(self, config, num_choices=2):
super(BertForMultipleChoice, self).__init__(config) super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices self.num_choices = num_choices
self.bert = BertModel(config) self.bert = BertModel(config)
@@ -1115,7 +1115,7 @@ class BertForTokenClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_labels): def __init__(self, config, num_labels=2):
super(BertForTokenClassification, self).__init__(config) super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels self.num_labels = num_labels
self.bert = BertModel(config) self.bert = BertModel(config)