From bfd6f6b257f2d4857f65bbcd6cb3487123fe848f Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Wed, 17 Apr 2019 13:39:46 -0700 Subject: [PATCH] fix from_pretrained positional args --- hubconf.py | 214 ++++++++++++++-------------- pytorch_pretrained_bert/modeling.py | 10 +- 2 files changed, 115 insertions(+), 109 deletions(-) diff --git a/hubconf.py b/hubconf.py index b2e44af5d9..755e181d20 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,14 +1,55 @@ from pytorch_pretrained_bert.tokenization import BertTokenizer -from pytorch_pretrained_bert.modeling import (BertForNextSentencePrediction, - BertForMaskedLM, - BertForMultipleChoice, - BertForPreTraining, - BertForQuestionAnswering, - BertForSequenceClassification, - ) +from pytorch_pretrained_bert.modeling import ( + BertModel, + BertForNextSentencePrediction, + BertForMaskedLM, + BertForMultipleChoice, + BertForPreTraining, + BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, + ) dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex'] +# A lot of models share the same param doc. Use a decorator +# to save typing +bert_docstring = """ + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining + instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow + checkpoint + cache_dir: an optional path to a folder in which the pre-trained models + will be cached. 
+ state_dict: an optional state dictionary + (collections.OrderedDict object) to use instead of Google + pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) +""" + + +def _append_from_pretrained_docstring(docstr): + def docstring_decorator(fn): + fn.__doc__ = fn.__doc__ + docstr + return fn + return docstring_decorator + def bertTokenizer(*args, **kwargs): """ @@ -43,7 +84,7 @@ def bertTokenizer(*args, **kwargs): Example: >>> sentence = 'Hello, World!' - >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'BertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False) + >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False) >>> toks = tokenizer.tokenize(sentence) ['Hello', '##,', 'World', '##!'] >>> ids = tokenizer.convert_tokens_to_ids(toks) @@ -53,135 +94,94 @@ def bertTokenizer(*args, **kwargs): return tokenizer +@_append_from_pretrained_docstring(bert_docstring) +def bertModel(*args, **kwargs): + """ + BertModel is the basic BERT Transformer model with a layer of summed token, + position and sequence embeddings followed by a series of identical + self-attention blocks (12 for BERT-base, 24 for BERT-large). + """ + model = BertModel.from_pretrained(*args, **kwargs) + return model + + +@_append_from_pretrained_docstring(bert_docstring) def bertForNextSentencePrediction(*args, **kwargs): - """BERT model with next sentence prediction head. - This module comprises the BERT model followed by the next sentence classification head. - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . 
`bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) + """ + BERT model with next sentence prediction head. + This module comprises the BERT model followed by the next sentence + classification head. """ model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs) return model +@_append_from_pretrained_docstring(bert_docstring) def bertForPreTraining(*args, **kwargs): - """BERT model with pre-training heads. - This module comprises the BERT model followed by the two pre-training heads: + """ + BERT model with pre-training heads. + This module comprises the BERT model followed by the two pre-training heads - the masked language modeling head, and - the next sentence classification head. - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . 
`pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ model = BertForPreTraining.from_pretrained(*args, **kwargs) return model +@_append_from_pretrained_docstring(bert_docstring) def bertForMaskedLM(*args, **kwargs): """ - BertForMaskedLM includes the BertModel Transformer followed by the (possibly) - pre-trained masked language modeling head. - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. 
- state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) + BertForMaskedLM includes the BertModel Transformer followed by the + (possibly) pre-trained masked language modeling head. """ model = BertForMaskedLM.from_pretrained(*args, **kwargs) return model -#def bertForSequenceClassification(*args, **kwargs): -# model = BertForSequenceClassification.from_pretrained(*args, **kwargs) -# return model +@_append_from_pretrained_docstring(bert_docstring) +def bertForSequenceClassification(*args, **kwargs): + """ + BertForSequenceClassification is a fine-tuning model that includes + BertModel and a sequence-level (sequence or pair of sequences) classifier + on top of the BertModel. + + The sequence-level classifier is a linear layer that takes as input the + last hidden state of the first character in the input sequence + (see Figures 3a and 3b in the BERT paper). + """ + model = BertForSequenceClassification.from_pretrained(*args, **kwargs) + return model -#def bertForMultipleChoice(*args, **kwargs): -# model = BertForMultipleChoice.from_pretrained(*args, **kwargs) -# return model +@_append_from_pretrained_docstring(bert_docstring) +def bertForMultipleChoice(*args, **kwargs): + """ + BertForMultipleChoice is a fine-tuning model that includes BertModel and a + linear layer on top of the BertModel. + """ + model = BertForMultipleChoice.from_pretrained(*args, **kwargs) + return model +@_append_from_pretrained_docstring(bert_docstring) def bertForQuestionAnswering(*args, **kwargs): """ - BertForQuestionAnswering is a fine-tuning model that includes BertModel with - a token-level classifiers on top of the full sequence of last hidden states. - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . 
`bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) + BertForQuestionAnswering is a fine-tuning model that includes BertModel + with a token-level classifier on top of the full sequence of last hidden + states. """ model = BertForQuestionAnswering.from_pretrained(*args, **kwargs) return model +@_append_from_pretrained_docstring(bert_docstring) +def bertForTokenClassification(*args, **kwargs): + """ + BertForTokenClassification is a fine-tuning model that includes BertModel + and a token-level classifier on top of the BertModel. + The token-level classifier is a linear layer that takes as input the last + hidden state of the sequence. 
+ """ + model = BertForTokenClassification.from_pretrained(*args, **kwargs) + return model diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 2736e34d7f..9c9b031970 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -519,8 +519,7 @@ class BertPreTrainedModel(nn.Module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, - from_tf=False, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): """ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. @@ -547,6 +546,13 @@ class BertPreTrainedModel(nn.Module): *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ + state_dict = kwargs.get('state_dict', None) + kwargs.pop('state_dict', None) + cache_dir = kwargs.get('cache_dir', None) + kwargs.pop('cache_dir', None) + from_tf = kwargs.get('from_tf', False) + kwargs.pop('from_tf', None) + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] else: