Make PreTrainedModel.from_pretrained pass unused arguments to model

This commit is contained in:
Anish Moorthy
2019-07-22 17:56:27 -04:00
parent 2f869dc665
commit b8009cb0da

View File

@@ -78,7 +78,7 @@ class PretrainedConfig(object):
self.to_json_file(output_config_file) self.to_json_file(output_config_file)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
r""" Instantiate a PretrainedConfig from a pre-trained model configuration. r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params: Params:
@@ -105,6 +105,7 @@ class PretrainedConfig(object):
""" """
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
return_unused_args = kwargs.pop('return_unused_args', False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -148,7 +149,10 @@ class PretrainedConfig(object):
kwargs.pop(key, None) kwargs.pop(key, None)
logger.info("Model config %s", config) logger.info("Model config %s", config)
return config if return_unused_args:
return config, kwargs
else:
return config
@classmethod @classmethod
def from_dict(cls, json_object): def from_dict(cls, json_object):
@@ -305,7 +309,7 @@ class PreTrainedModel(nn.Module):
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated) The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated)
@@ -336,9 +340,17 @@ class PreTrainedModel(nn.Module):
configuration should be cached if the standard cache should not be used. configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean: **output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
**model_args**: (`optional`) Sequence:
All positional arguments will be passed to the underlying model's __init__ function
**kwargs**: (`optional`) dict: **kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading. Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True`` Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
If config is None, then **kwargs will be passed to the model.
If said key is *not* present, then kwargs will be used to
override any keys shared with the default configuration for the
given pretrained_model_name_or_path, and only the unshared
key/value pairs will be passed to the model.
Examples:: Examples::
@@ -359,7 +371,12 @@ class PreTrainedModel(nn.Module):
# Load config # Load config
if config is None: if config is None:
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args,
return_unused_args=True, **kwargs
)
else:
model_kwargs = kwargs
# Load model # Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
@@ -400,7 +417,7 @@ class PreTrainedModel(nn.Module):
archive_file, resolved_archive_file)) archive_file, resolved_archive_file))
# Instantiate model. # Instantiate model.
model = cls(config) model = cls(config, *model_args, **model_kwargs)
if state_dict is None and not from_tf: if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu') state_dict = torch.load(resolved_archive_file, map_location='cpu')
@@ -530,7 +547,7 @@ class PoolerEndLogits(nn.Module):
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
hidden states of the first tokens for the labeled span. hidden states of the first tokens for the labeled span.
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
position of the first token for the labeled span: position of the first token for the labeled span:
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
Mask of invalid position such as query and special symbols (PAD, SEP, CLS) Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
1.0 means token should be masked. 1.0 means token should be masked.
@@ -717,7 +734,7 @@ class SequenceSummary(nn.Module):
- 'attn' => Not implemented now, use multi-head attention - 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction summary_use_proj: Add a projection after the vector extraction
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
summary_first_dropout: Add a dropout before the projection and activation summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation summary_last_dropout: Add a dropout after the projection and activation
""" """