Make PreTrainedModel.from_pretrained pass unused arguments to model
This commit is contained in:
@@ -78,7 +78,7 @@ class PretrainedConfig(object):
|
|||||||
self.to_json_file(output_config_file)
|
self.to_json_file(output_config_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||||
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
|
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
@@ -105,6 +105,7 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
|
return_unused_args = kwargs.pop('return_unused_args', False)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||||
@@ -148,7 +149,10 @@ class PretrainedConfig(object):
|
|||||||
kwargs.pop(key, None)
|
kwargs.pop(key, None)
|
||||||
|
|
||||||
logger.info("Model config %s", config)
|
logger.info("Model config %s", config)
|
||||||
return config
|
if return_unused_args:
|
||||||
|
return config, kwargs
|
||||||
|
else:
|
||||||
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, json_object):
|
def from_dict(cls, json_object):
|
||||||
@@ -305,7 +309,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||||
|
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
@@ -336,9 +340,17 @@ class PreTrainedModel(nn.Module):
|
|||||||
configuration should be cached if the standard cache should not be used.
|
configuration should be cached if the standard cache should not be used.
|
||||||
**output_loading_info**: (`optional`) boolean:
|
**output_loading_info**: (`optional`) boolean:
|
||||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
**model_args**: (`optional`) Sequence:
|
||||||
|
All positional arguments will be passed to the underlying model's __init__ function
|
||||||
**kwargs**: (`optional`) dict:
|
**kwargs**: (`optional`) dict:
|
||||||
Dictionnary of key, values to update the configuration object after loading.
|
Dictionary of key, values to update the configuration object after loading.
|
||||||
Can be used to override selected configuration parameters. E.g. ``output_attention=True``
|
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
|
||||||
|
|
||||||
|
If a config is provided (config is not None), the remaining **kwargs will be passed directly to the model.
|
||||||
|
If config is None, then kwargs will first be used to
|
||||||
|
override any keys shared with the default configuration for the
|
||||||
|
given pretrained_model_name_or_path, and only the unshared
|
||||||
|
key/value pairs will be passed to the model.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -359,7 +371,12 @@ class PreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
if config is None:
|
if config is None:
|
||||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args,
|
||||||
|
return_unused_args=True, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_kwargs = kwargs
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
@@ -400,7 +417,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
archive_file, resolved_archive_file))
|
archive_file, resolved_archive_file))
|
||||||
|
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not from_tf:
|
||||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||||
@@ -530,7 +547,7 @@ class PoolerEndLogits(nn.Module):
|
|||||||
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
|
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
|
||||||
hidden states of the first tokens for the labeled span.
|
hidden states of the first tokens for the labeled span.
|
||||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||||
position of the first token for the labeled span:
|
position of the first token for the labeled span:
|
||||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||||
1.0 means token should be masked.
|
1.0 means token should be masked.
|
||||||
@@ -717,7 +734,7 @@ class SequenceSummary(nn.Module):
|
|||||||
- 'attn' => Not implemented now, use multi-head attention
|
- 'attn' => Not implemented now, use multi-head attention
|
||||||
summary_use_proj: Add a projection after the vector extraction
|
summary_use_proj: Add a projection after the vector extraction
|
||||||
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
|
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
|
||||||
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
|
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
|
||||||
summary_first_dropout: Add a dropout before the projection and activation
|
summary_first_dropout: Add a dropout before the projection and activation
|
||||||
summary_last_dropout: Add a dropout after the projection and activation
|
summary_last_dropout: Add a dropout after the projection and activation
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user