add pretrained loading from state_dict
This commit is contained in:
@@ -445,9 +445,9 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
@@ -461,6 +461,8 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
- a path or url to a pretrained model archive containing:
|
- a path or url to a pretrained model archive containing:
|
||||||
. `bert_config.json` a configuration file for the model
|
. `bert_config.json` a configuration file for the model
|
||||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
||||||
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||||
|
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
(ex: num_labels for BertForSequenceClassification)
|
||||||
"""
|
"""
|
||||||
@@ -502,6 +504,7 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
|
if state_dict is None:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||||
state_dict = torch.load(weights_path)
|
state_dict = torch.load(weights_path)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user