From b8009cb0dac9698c1999af7121ea196065510905 Mon Sep 17 00:00:00 2001 From: Anish Moorthy Date: Mon, 22 Jul 2019 17:56:27 -0400 Subject: [PATCH] Make PreTrainedModel.from_pretrained pass unused arguments to model --- pytorch_transformers/modeling_utils.py | 35 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 324cdc17c9..a4e1a44c9d 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -78,7 +78,7 @@ class PretrainedConfig(object): self.to_json_file(output_config_file) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): r""" Instantiate a PretrainedConfig from a pre-trained model configuration. Params: @@ -105,6 +105,7 @@ class PretrainedConfig(object): """ cache_dir = kwargs.pop('cache_dir', None) + return_unused_args = kwargs.pop('return_unused_args', False) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] @@ -148,7 +149,10 @@ class PretrainedConfig(object): kwargs.pop(key, None) logger.info("Model config %s", config) - return config + if return_unused_args: + return config, kwargs + else: + return config @classmethod def from_dict(cls, json_object): @@ -305,7 +309,7 @@ class PreTrainedModel(nn.Module): torch.save(model_to_save.state_dict(), output_model_file) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. 
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated) @@ -336,9 +340,17 @@ class PreTrainedModel(nn.Module): configuration should be cached if the standard cache should not be used. **output_loading_info**: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + **model_args**: (`optional`) Sequence: + All positional arguments will be passed to the underlying model's __init__ function **kwargs**: (`optional`) dict: - Dictionnary of key, values to update the configuration object after loading. - Can be used to override selected configuration parameters. E.g. ``output_attention=True`` + Dictionary of key, values to update the configuration object after loading. + Can be used to override selected configuration parameters. E.g. ``output_attention=True``. + + If config is provided (not None), then **kwargs will be passed directly to the model. + If config is *not* provided, then kwargs will first be used to + override any keys shared with the default configuration for the + given pretrained_model_name_or_path, and only the unshared + key/value pairs will be passed to the model. Examples:: @@ -359,7 +371,12 @@ class PreTrainedModel(nn.Module): # Load config if config is None: - config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + config, model_kwargs = cls.config_class.from_pretrained( + pretrained_model_name_or_path, *model_args, + return_unused_args=True, **kwargs + ) + else: + model_kwargs = kwargs # Load model if pretrained_model_name_or_path in cls.pretrained_model_archive_map: @@ -400,7 +417,7 @@ class PreTrainedModel(nn.Module): archive_file, resolved_archive_file)) # Instantiate model. 
- model = cls(config) + model = cls(config, *model_args, **model_kwargs) if state_dict is None and not from_tf: state_dict = torch.load(resolved_archive_file, map_location='cpu') @@ -530,7 +547,7 @@ class PoolerEndLogits(nn.Module): **start_states**: ``torch.LongTensor`` of shape identical to hidden_states hidden states of the first tokens for the labeled span. **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` - position of the first token for the labeled span: + position of the first token for the labeled span: **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 1.0 means token should be masked. @@ -717,7 +734,7 @@ class SequenceSummary(nn.Module): - 'attn' => Not implemented now, use multi-head attention summary_use_proj: Add a projection after the vector extraction summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. - summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default + summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default summary_first_dropout: Add a dropout before the projection and activation summary_last_dropout: Add a dropout after the projection and activation """