Add docs for from_pretrained functions, rename return_unused_args
This commit is contained in:
@@ -91,21 +91,33 @@ class PretrainedConfig(object):
|
|||||||
**cache_dir**: (`optional`) string:
|
**cache_dir**: (`optional`) string:
|
||||||
Path to a directory in which a downloaded pre-trained model
|
Path to a directory in which a downloaded pre-trained model
|
||||||
configuration should be cached if the standard cache should not be used.
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
**return_unused_kwargs**: (`optional`) bool:
|
||||||
|
- If False, then this function returns just the final configuration object.
|
||||||
|
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
|
||||||
|
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
|
||||||
|
ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
|
||||||
**kwargs**: (`optional`) dict:
|
**kwargs**: (`optional`) dict:
|
||||||
Dictionnary of key, values to update the configuration object after loading.
|
Dictionary of key/value pairs with which to update the configuration object after loading.
|
||||||
Can be used to override selected configuration parameters.
|
- The values in kwargs of any keys which are configuration attributes will be used
|
||||||
|
to override the loaded values.
|
||||||
|
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
||||||
|
by the `return_unused_kwargs` keyword parameter.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
>>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
>>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
>>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
|
>>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
|
||||||
>>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
|
>>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
|
||||||
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
|
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
|
||||||
>>> assert config.output_attention == True
|
>>> assert config.output_attention == True
|
||||||
|
>>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
|
||||||
|
>>> foo=False, return_unused_kwargs=True)
|
||||||
|
>>> assert config.output_attention == True
|
||||||
|
>>> assert unused_kwargs == {'foo': False}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
return_unused_args = kwargs.pop('return_unused_args', False)
|
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||||
@@ -149,7 +161,7 @@ class PretrainedConfig(object):
|
|||||||
kwargs.pop(key, None)
|
kwargs.pop(key, None)
|
||||||
|
|
||||||
logger.info("Model config %s", config)
|
logger.info("Model config %s", config)
|
||||||
if return_unused_args:
|
if return_unused_kwargs:
|
||||||
return config, kwargs
|
return config, kwargs
|
||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
@@ -326,6 +338,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
provided as `config` argument. This loading option is slower than converting the TensorFlow
|
provided as `config` argument. This loading option is slower than converting the TensorFlow
|
||||||
checkpoint in a PyTorch model using the provided conversion scripts and loading
|
checkpoint in a PyTorch model using the provided conversion scripts and loading
|
||||||
the PyTorch model afterwards.
|
the PyTorch model afterwards.
|
||||||
|
**model_args**: (`optional`) Sequence:
|
||||||
|
All remaning positional arguments will be passed to the underlying model's __init__ function
|
||||||
**config**: an optional configuration for the model to use instead of an automatically loaded configuation.
|
**config**: an optional configuration for the model to use instead of an automatically loaded configuation.
|
||||||
Configuration can be automatically loaded when:
|
Configuration can be automatically loaded when:
|
||||||
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
|
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
|
||||||
@@ -340,17 +354,18 @@ class PreTrainedModel(nn.Module):
|
|||||||
configuration should be cached if the standard cache should not be used.
|
configuration should be cached if the standard cache should not be used.
|
||||||
**output_loading_info**: (`optional`) boolean:
|
**output_loading_info**: (`optional`) boolean:
|
||||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||||
**model_args**: (`optional`) Sequence:
|
|
||||||
All positional arguments will be passed to the underlying model's __init__ function
|
|
||||||
**kwargs**: (`optional`) dict:
|
**kwargs**: (`optional`) dict:
|
||||||
Dictionary of key, values to update the configuration object after loading.
|
Dictionary of key, values to update the configuration object after loading.
|
||||||
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
|
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
|
||||||
|
|
||||||
If config is None, then **kwargs will be passed to the model.
|
- If a configuration is provided with `config`, **kwargs will be directly passed
|
||||||
If said key is *not* present, then kwargs will be used to
|
to the underlying model's __init__ method.
|
||||||
override any keys shared with the default configuration for the
|
- If a configuration is not provided, **kwargs will be first passed to the pretrained
|
||||||
given pretrained_model_name_or_path, and only the unshared
|
model configuration class loading function (`PretrainedConfig.from_pretrained`).
|
||||||
key/value pairs will be passed to the model.
|
Each key of **kwargs that corresponds to a configuration attribute
|
||||||
|
will be used to override said attribute with the supplied **kwargs value.
|
||||||
|
Remaining keys that do not correspond to any configuration attribute will
|
||||||
|
be passed to the underlying model's __init__ function.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -373,7 +388,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
if config is None:
|
if config is None:
|
||||||
config, model_kwargs = cls.config_class.from_pretrained(
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
pretrained_model_name_or_path, *model_args,
|
pretrained_model_name_or_path, *model_args,
|
||||||
cache_dir=cache_dir, return_unused_args=True,
|
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user