big doc update [WIP]

This commit is contained in:
thomwolf
2019-08-04 12:14:57 +02:00
parent bfbe52ec39
commit 009273dbdd
19 changed files with 189 additions and 60 deletions

View File

@@ -55,11 +55,19 @@ else:
class PretrainedConfig(object):
""" Base class for all configuration classes.
Handle a few common parameters and methods for loading/downloading/saving configurations.
Handle a few common attributes and methods for loading/downloading/saving configurations.
"""
pretrained_config_archive_map = {}
def __init__(self, **kwargs):
r""" The initialization of :class:`~pytorch_transformers.PretrainedConfig` extracts
a few configuration attributes from `**kwargs` which are common to all models:
- `finetuning_task`: string, default `None`. Name of the task used to fine-tune the model (used to convert from original checkpoint)
- `num_labels`: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
- `output_attentions`: boolean, default `False`. Should the model returns attentions weights.
- `output_hidden_states`: string, default `False`. Should the model returns all hidden-states.
- `torchscript`: string, default `False`. Is the model used with Torchscript.
"""
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False)
@@ -67,7 +75,7 @@ class PretrainedConfig(object):
self.torchscript = kwargs.pop('torchscript', False)
def save_pretrained(self, save_directory):
""" Save a configuration object to a directory, so that it
""" Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
@@ -81,30 +89,34 @@ class PretrainedConfig(object):
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
Parameters:
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a saved configuration `file`.
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing a configuration file saved using the `save_pretrained(save_directory)` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**return_unused_kwargs**: (`optional`) bool:
- If False, then this function returns just the final configuration object.
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
**kwargs**: (`optional`) dict:
Dictionary of key/value pairs with which to update the configuration object after loading.
- The values in kwargs of any keys which are configuration attributes will be used
to override the loaded values.
to override the loaded values.
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the `return_unused_kwargs` keyword parameter.
by the `return_unused_kwargs` keyword parameter.
Examples::
# We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
# derived class: BertConfig
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')