fix from_pretrained positional args
This commit is contained in:
214
hubconf.py
214
hubconf.py
@@ -1,14 +1,55 @@
|
|||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
from pytorch_pretrained_bert.modeling import (BertForNextSentencePrediction,
|
from pytorch_pretrained_bert.modeling import (
|
||||||
BertForMaskedLM,
|
BertModel,
|
||||||
BertForMultipleChoice,
|
BertForNextSentencePrediction,
|
||||||
BertForPreTraining,
|
BertForMaskedLM,
|
||||||
BertForQuestionAnswering,
|
BertForMultipleChoice,
|
||||||
BertForSequenceClassification,
|
BertForPreTraining,
|
||||||
)
|
BertForQuestionAnswering,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BertForTokenClassification,
|
||||||
|
)
|
||||||
|
|
||||||
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
|
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
|
||||||
|
|
||||||
|
# A lot of models share the same param doc. Use a decorator
|
||||||
|
# to save typing
|
||||||
|
bert_docstring = """
|
||||||
|
Params:
|
||||||
|
pretrained_model_name_or_path: either:
|
||||||
|
- a str with the name of a pre-trained model to load
|
||||||
|
. `bert-base-uncased`
|
||||||
|
. `bert-large-uncased`
|
||||||
|
. `bert-base-cased`
|
||||||
|
. `bert-large-cased`
|
||||||
|
. `bert-base-multilingual-uncased`
|
||||||
|
. `bert-base-multilingual-cased`
|
||||||
|
. `bert-base-chinese`
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `bert_config.json` a configuration file for the model
|
||||||
|
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining
|
||||||
|
instance
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `bert_config.json` a configuration file for the model
|
||||||
|
. `model.chkpt` a TensorFlow checkpoint
|
||||||
|
from_tf: should we load the weights from a locally saved TensorFlow
|
||||||
|
checkpoint
|
||||||
|
cache_dir: an optional path to a folder in which the pre-trained models
|
||||||
|
will be cached.
|
||||||
|
state_dict: an optional state dictionnary
|
||||||
|
(collections.OrderedDict object) to use instead of Google
|
||||||
|
pre-trained models
|
||||||
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
|
(ex: num_labels for BertForSequenceClassification)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _append_from_pretrained_docstring(docstr):
|
||||||
|
def docstring_decorator(fn):
|
||||||
|
fn.__doc__ = fn.__doc__ + docstr
|
||||||
|
return fn
|
||||||
|
return docstring_decorator
|
||||||
|
|
||||||
|
|
||||||
def bertTokenizer(*args, **kwargs):
|
def bertTokenizer(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -43,7 +84,7 @@ def bertTokenizer(*args, **kwargs):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> sentence = 'Hello, World!'
|
>>> sentence = 'Hello, World!'
|
||||||
>>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'BertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
|
>>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
|
||||||
>>> toks = tokenizer.tokenize(sentence)
|
>>> toks = tokenizer.tokenize(sentence)
|
||||||
['Hello', '##,', 'World', '##!']
|
['Hello', '##,', 'World', '##!']
|
||||||
>>> ids = tokenizer.convert_tokens_to_ids(toks)
|
>>> ids = tokenizer.convert_tokens_to_ids(toks)
|
||||||
@@ -53,135 +94,94 @@ def bertTokenizer(*args, **kwargs):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
|
def bertModel(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
BertModel is the basic BERT Transformer model with a layer of summed token,
|
||||||
|
position and sequence embeddings followed by a series of identical
|
||||||
|
self-attention blocks (12 for BERT-base, 24 for BERT-large).
|
||||||
|
"""
|
||||||
|
model = BertModel.from_pretrained(*args, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
def bertForNextSentencePrediction(*args, **kwargs):
|
def bertForNextSentencePrediction(*args, **kwargs):
|
||||||
"""BERT model with next sentence prediction head.
|
"""
|
||||||
This module comprises the BERT model followed by the next sentence classification head.
|
BERT model with next sentence prediction head.
|
||||||
Params:
|
This module comprises the BERT model followed by the next sentence
|
||||||
pretrained_model_name_or_path: either:
|
classification head.
|
||||||
- a str with the name of a pre-trained model to load selected in the list of:
|
|
||||||
. `bert-base-uncased`
|
|
||||||
. `bert-large-uncased`
|
|
||||||
. `bert-base-cased`
|
|
||||||
. `bert-large-cased`
|
|
||||||
. `bert-base-multilingual-uncased`
|
|
||||||
. `bert-base-multilingual-cased`
|
|
||||||
. `bert-base-chinese`
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `model.chkpt` a TensorFlow checkpoint
|
|
||||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
|
||||||
"""
|
"""
|
||||||
model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
|
model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
def bertForPreTraining(*args, **kwargs):
|
def bertForPreTraining(*args, **kwargs):
|
||||||
"""BERT model with pre-training heads.
|
"""
|
||||||
This module comprises the BERT model followed by the two pre-training heads:
|
BERT model with pre-training heads.
|
||||||
|
This module comprises the BERT model followed by the two pre-training heads
|
||||||
- the masked language modeling head, and
|
- the masked language modeling head, and
|
||||||
- the next sentence classification head.
|
- the next sentence classification head.
|
||||||
Params:
|
|
||||||
pretrained_model_name_or_path: either:
|
|
||||||
- a str with the name of a pre-trained model to load selected in the list of:
|
|
||||||
. `bert-base-uncased`
|
|
||||||
. `bert-large-uncased`
|
|
||||||
. `bert-base-cased`
|
|
||||||
. `bert-large-cased`
|
|
||||||
. `bert-base-multilingual-uncased`
|
|
||||||
. `bert-base-multilingual-cased`
|
|
||||||
. `bert-base-chinese`
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `model.chkpt` a TensorFlow checkpoint
|
|
||||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model = BertForPreTraining.from_pretrained(*args, **kwargs)
|
model = BertForPreTraining.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
def bertForMaskedLM(*args, **kwargs):
|
def bertForMaskedLM(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
BertForMaskedLM includes the BertModel Transformer followed by the (possibly)
|
BertForMaskedLM includes the BertModel Transformer followed by the
|
||||||
pre-trained masked language modeling head.
|
(possibly) pre-trained masked language modeling head.
|
||||||
Params:
|
|
||||||
pretrained_model_name_or_path: either:
|
|
||||||
- a str with the name of a pre-trained model to load selected in the list of:
|
|
||||||
. `bert-base-uncased`
|
|
||||||
. `bert-large-uncased`
|
|
||||||
. `bert-base-cased`
|
|
||||||
. `bert-large-cased`
|
|
||||||
. `bert-base-multilingual-uncased`
|
|
||||||
. `bert-base-multilingual-cased`
|
|
||||||
. `bert-base-chinese`
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `model.chkpt` a TensorFlow checkpoint
|
|
||||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
|
||||||
"""
|
"""
|
||||||
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
|
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
#def bertForSequenceClassification(*args, **kwargs):
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
# model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
|
def bertForSequenceClassification(*args, **kwargs):
|
||||||
# return model
|
"""
|
||||||
|
BertForSequenceClassification is a fine-tuning model that includes
|
||||||
|
BertModel and a sequence-level (sequence or pair of sequences) classifier
|
||||||
|
on top of the BertModel.
|
||||||
|
|
||||||
|
The sequence-level classifier is a linear layer that takes as input the
|
||||||
|
last hidden state of the first character in the input sequence
|
||||||
|
(see Figures 3a and 3b in the BERT paper).
|
||||||
|
"""
|
||||||
|
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
#def bertForMultipleChoice(*args, **kwargs):
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
# model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
|
def bertForMultipleChoice(*args, **kwargs):
|
||||||
# return model
|
"""
|
||||||
|
BertForMultipleChoice is a fine-tuning model that includes BertModel and a
|
||||||
|
linear layer on top of the BertModel.
|
||||||
|
"""
|
||||||
|
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
def bertForQuestionAnswering(*args, **kwargs):
|
def bertForQuestionAnswering(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
BertForQuestionAnswering is a fine-tuning model that includes BertModel with
|
BertForQuestionAnswering is a fine-tuning model that includes BertModel
|
||||||
a token-level classifiers on top of the full sequence of last hidden states.
|
with a token-level classifiers on top of the full sequence of last hidden
|
||||||
Params:
|
states.
|
||||||
pretrained_model_name_or_path: either:
|
|
||||||
- a str with the name of a pre-trained model to load selected in the list of:
|
|
||||||
. `bert-base-uncased`
|
|
||||||
. `bert-large-uncased`
|
|
||||||
. `bert-base-cased`
|
|
||||||
. `bert-large-cased`
|
|
||||||
. `bert-base-multilingual-uncased`
|
|
||||||
. `bert-base-multilingual-cased`
|
|
||||||
. `bert-base-chinese`
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
|
||||||
- a path or url to a pretrained model archive containing:
|
|
||||||
. `bert_config.json` a configuration file for the model
|
|
||||||
. `model.chkpt` a TensorFlow checkpoint
|
|
||||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
|
||||||
"""
|
"""
|
||||||
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
|
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@_append_from_pretrained_docstring(bert_docstring)
|
||||||
|
def bertForTokenClassification(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
BertForTokenClassification is a fine-tuning model that includes BertModel
|
||||||
|
and a token-level classifier on top of the BertModel.
|
||||||
|
|
||||||
|
The token-level classifier is a linear layer that takes as input the last
|
||||||
|
hidden state of the sequence.
|
||||||
|
"""
|
||||||
|
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
|
||||||
|
return model
|
||||||
|
|||||||
@@ -519,8 +519,7 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||||
from_tf=False, *inputs, **kwargs):
|
|
||||||
"""
|
"""
|
||||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -547,6 +546,13 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
(ex: num_labels for BertForSequenceClassification)
|
||||||
"""
|
"""
|
||||||
|
state_dict = kwargs.get('state_dict', None)
|
||||||
|
kwargs.pop('state_dict', None)
|
||||||
|
cache_dir = kwargs.get('cache_dir', None)
|
||||||
|
kwargs.pop('cache_dir', None)
|
||||||
|
from_tf = kwargs.get('from_tf', False)
|
||||||
|
kwargs.pop('from_tf', None)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user