187
hubconf.py
Normal file
187
hubconf.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.modeling import (
|
||||
BertModel,
|
||||
BertForNextSentencePrediction,
|
||||
BertForMaskedLM,
|
||||
BertForMultipleChoice,
|
||||
BertForPreTraining,
|
||||
BertForQuestionAnswering,
|
||||
BertForSequenceClassification,
|
||||
BertForTokenClassification,
|
||||
)
|
||||
|
||||
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
|
||||
|
||||
# A lot of models share the same param doc. Use a decorator
|
||||
# to save typing
|
||||
bert_docstring = """
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load
|
||||
. `bert-base-uncased`
|
||||
. `bert-large-uncased`
|
||||
. `bert-base-cased`
|
||||
. `bert-large-cased`
|
||||
. `bert-base-multilingual-uncased`
|
||||
. `bert-base-multilingual-cased`
|
||||
. `bert-base-chinese`
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining
|
||||
instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
from_tf: should we load the weights from a locally saved TensorFlow
|
||||
checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models
|
||||
will be cached.
|
||||
state_dict: an optional state dictionnary
|
||||
(collections.OrderedDict object) to use instead of Google
|
||||
pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
|
||||
|
||||
def _append_from_pretrained_docstring(docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = fn.__doc__ + docstr
|
||||
return fn
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def bertTokenizer(*args, **kwargs):
|
||||
"""
|
||||
Instantiate a BertTokenizer from a pre-trained/customized vocab file
|
||||
Args:
|
||||
pretrained_model_name_or_path: Path to pretrained model archive
|
||||
or one of pre-trained vocab configs below.
|
||||
* bert-base-uncased
|
||||
* bert-large-uncased
|
||||
* bert-base-cased
|
||||
* bert-large-cased
|
||||
* bert-base-multilingual-uncased
|
||||
* bert-base-multilingual-cased
|
||||
* bert-base-chinese
|
||||
Keyword args:
|
||||
cache_dir: an optional path to a specific directory to download and cache
|
||||
the pre-trained model weights.
|
||||
Default: None
|
||||
do_lower_case: Whether to lower case the input.
|
||||
Only has an effect when do_wordpiece_only=False
|
||||
Default: True
|
||||
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
|
||||
Default: True
|
||||
max_len: An artificial maximum length to truncate tokenized sequences to;
|
||||
Effective maximum length is always the minimum of this
|
||||
value (if specified) and the underlying BERT model's
|
||||
sequence length.
|
||||
Default: None
|
||||
never_split: List of tokens which will never be split during tokenization.
|
||||
Only has an effect when do_wordpiece_only=False
|
||||
Default: ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]
|
||||
|
||||
Example:
|
||||
>>> sentence = 'Hello, World!'
|
||||
>>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
|
||||
>>> toks = tokenizer.tokenize(sentence)
|
||||
['Hello', '##,', 'World', '##!']
|
||||
>>> ids = tokenizer.convert_tokens_to_ids(toks)
|
||||
[8667, 28136, 1291, 28125]
|
||||
"""
|
||||
tokenizer = BertTokenizer.from_pretrained(*args, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertModel(*args, **kwargs):
|
||||
"""
|
||||
BertModel is the basic BERT Transformer model with a layer of summed token,
|
||||
position and sequence embeddings followed by a series of identical
|
||||
self-attention blocks (12 for BERT-base, 24 for BERT-large).
|
||||
"""
|
||||
model = BertModel.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForNextSentencePrediction(*args, **kwargs):
|
||||
"""
|
||||
BERT model with next sentence prediction head.
|
||||
This module comprises the BERT model followed by the next sentence
|
||||
classification head.
|
||||
"""
|
||||
model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForPreTraining(*args, **kwargs):
|
||||
"""
|
||||
BERT model with pre-training heads.
|
||||
This module comprises the BERT model followed by the two pre-training heads
|
||||
- the masked language modeling head, and
|
||||
- the next sentence classification head.
|
||||
"""
|
||||
model = BertForPreTraining.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForMaskedLM(*args, **kwargs):
|
||||
"""
|
||||
BertForMaskedLM includes the BertModel Transformer followed by the
|
||||
(possibly) pre-trained masked language modeling head.
|
||||
"""
|
||||
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForSequenceClassification(*args, **kwargs):
|
||||
"""
|
||||
BertForSequenceClassification is a fine-tuning model that includes
|
||||
BertModel and a sequence-level (sequence or pair of sequences) classifier
|
||||
on top of the BertModel.
|
||||
|
||||
The sequence-level classifier is a linear layer that takes as input the
|
||||
last hidden state of the first character in the input sequence
|
||||
(see Figures 3a and 3b in the BERT paper).
|
||||
"""
|
||||
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForMultipleChoice(*args, **kwargs):
|
||||
"""
|
||||
BertForMultipleChoice is a fine-tuning model that includes BertModel and a
|
||||
linear layer on top of the BertModel.
|
||||
"""
|
||||
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForQuestionAnswering(*args, **kwargs):
|
||||
"""
|
||||
BertForQuestionAnswering is a fine-tuning model that includes BertModel
|
||||
with a token-level classifiers on top of the full sequence of last hidden
|
||||
states.
|
||||
"""
|
||||
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@_append_from_pretrained_docstring(bert_docstring)
|
||||
def bertForTokenClassification(*args, **kwargs):
|
||||
"""
|
||||
BertForTokenClassification is a fine-tuning model that includes BertModel
|
||||
and a token-level classifier on top of the BertModel.
|
||||
|
||||
The token-level classifier is a linear layer that takes as input the last
|
||||
hidden state of the sequence.
|
||||
"""
|
||||
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
@@ -523,8 +523,7 @@ class BertPreTrainedModel(nn.Module):
|
||||
module.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
||||
from_tf=False, *inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@@ -551,6 +550,13 @@ class BertPreTrainedModel(nn.Module):
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
state_dict = kwargs.get('state_dict', None)
|
||||
kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.get('cache_dir', None)
|
||||
kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.get('from_tf', False)
|
||||
kwargs.pop('from_tf', None)
|
||||
|
||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user