Refactor checkpoint name in BERT and MobileBERT (#10424)
* Refactor checkpoint name in BERT and MobileBERT * Add option to check copies * Add QuestionAnswering * Add last models * Make black happy
This commit is contained in:
@@ -58,6 +58,7 @@ from .configuration_bert import BertConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
||||
_CONFIG_FOR_DOC = "BertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||
|
||||
@@ -862,7 +863,7 @@ class BertModel(BertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1273,7 +1274,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1468,7 +1469,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1552,7 +1553,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MultipleChoiceModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1647,7 +1648,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1737,7 +1738,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ from .configuration_mobilebert import MobileBertConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/mobilebert-uncased"
|
||||
_CONFIG_FOR_DOC = "MobileBertConfig"
|
||||
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
|
||||
|
||||
@@ -818,7 +819,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/mobilebert-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1033,7 +1034,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/mobilebert-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1204,20 +1205,22 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
""",
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
|
||||
class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.mobilebert = MobileBertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.num_labels)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/mobilebert-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1253,7 +1256,9 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
@@ -1286,6 +1291,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
""",
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
|
||||
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
@@ -1302,7 +1308,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/mobilebert-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1403,10 +1409,11 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/mobilebert-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MultipleChoiceModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.forward with Bert->MobileBert all-casing
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -1481,6 +1488,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
""",
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
|
||||
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
@@ -1498,7 +1506,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/mobilebert-uncased",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user