From e179c55490269432fd9c67fd867f555e81259a34 Mon Sep 17 00:00:00 2001 From: Anish Moorthy Date: Tue, 23 Jul 2019 10:39:51 -0400 Subject: [PATCH] Add docs for from_pretrained functions, rename return_unused_args --- pytorch_transformers/modeling_utils.py | 41 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 0a4bfa7ba0..3e8d2fbb1a 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -91,21 +91,33 @@ class PretrainedConfig(object): **cache_dir**: (`optional`) string: Path to a directory in which a downloaded pre-trained model configuration should be cached if the standard cache should not be used. + **return_unused_kwargs**: (`optional`) bool: + - If False, then this function returns just the final configuration object. + - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs` + is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: + i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored. **kwargs**: (`optional`) dict: - Dictionnary of key, values to update the configuration object after loading. - Can be used to override selected configuration parameters. + Dictionary of key/value pairs with which to update the configuration object after loading. + - The values in kwargs of any keys which are configuration attributes will be used + to override the loaded values. + - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled + by the `return_unused_kwargs` keyword parameter. Examples:: >>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. >>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') - >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True) + >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) >>> assert config.output_attention == True + >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, + ... foo=False, return_unused_kwargs=True) + >>> assert config.output_attention == True + >>> assert unused_kwargs == {'foo': False} """ cache_dir = kwargs.pop('cache_dir', None) - return_unused_args = kwargs.pop('return_unused_args', False) + return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] @@ -149,7 +161,7 @@ class PretrainedConfig(object): kwargs.pop(key, None) logger.info("Model config %s", config) - if return_unused_args: + if return_unused_kwargs: return config, kwargs else: return config @@ -326,6 +338,8 @@ class PreTrainedModel(nn.Module): provided as `config` argument. This loading option is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + **model_args**: (`optional`) Sequence: + All remaining positional arguments will be passed to the underlying model's __init__ function **config**: an optional configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or @@ -340,17 +354,18 @@ class PreTrainedModel(nn.Module): configuration should be cached if the standard cache should not be used. 
**output_loading_info**: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. - **model_args**: (`optional`) Sequence: - All positional arguments will be passed to the underlying model's __init__ function **kwargs**: (`optional`) dict: Dictionary of key, values to update the configuration object after loading. Can be used to override selected configuration parameters. E.g. ``output_attention=True``. - If config is None, then **kwargs will be passed to the model. - If said key is *not* present, then kwargs will be used to - override any keys shared with the default configuration for the - given pretrained_model_name_or_path, and only the unshared - key/value pairs will be passed to the model. + - If a configuration is provided with `config`, **kwargs will be directly passed + to the underlying model's __init__ method. + - If a configuration is not provided, **kwargs will be first passed to the pretrained + model configuration class loading function (`PretrainedConfig.from_pretrained`). + Each key of **kwargs that corresponds to a configuration attribute + will be used to override said attribute with the supplied **kwargs value. + Remaining keys that do not correspond to any configuration attribute will + be passed to the underlying model's __init__ function. Examples:: @@ -373,7 +388,7 @@ class PreTrainedModel(nn.Module): if config is None: config, model_kwargs = cls.config_class.from_pretrained( pretrained_model_name_or_path, *model_args, - cache_dir=cache_dir, return_unused_args=True, + cache_dir=cache_dir, return_unused_kwargs=True, **kwargs ) else: