Extend typing to path-like objects in PretrainedConfig and PreTrainedModel (#8770)
* update configuration_utils.py typing to allow path-like objects when sensible
* update modeling_utils.py typing to allow path-like objects when sensible
* black
* update tokenization_utils_base.py typing to allow path-like objects when sensible
* update tokenization_utils_fast.py typing to allow path-like objects when sensible
* update configuration_auto.py typing to allow path-like objects when sensible
* update configuration_auto.py docstring to allow path-like objects when sensible
* update tokenization_auto.py docstring to allow path-like objects when sensible
* black
This commit is contained in:
committed by
GitHub
parent
a7d46a0609
commit
f9a2a9e32b
@@ -19,7 +19,7 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||
from .utils import logging
|
||||
@@ -262,13 +262,13 @@ class PretrainedConfig(object):
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||
|
||||
def save_pretrained(self, save_directory: str):
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
|
||||
:func:`~transformers.PretrainedConfig.from_pretrained` class method.
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str`):
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -281,13 +281,13 @@ class PretrainedConfig(object):
|
||||
logger.info("Configuration saved in {}".format(output_config_file))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig":
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
|
||||
configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
|
||||
@@ -297,7 +297,7 @@ class PretrainedConfig(object):
|
||||
:func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
|
||||
- a path or url to a saved configuration JSON `file`, e.g.,
|
||||
``./my_model_directory/configuration.json``.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -346,13 +346,15 @@ class PretrainedConfig(object):
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
:class:`~transformers.PretrainedConfig` using ``from_dict``.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
|
||||
Returns:
|
||||
@@ -366,6 +368,7 @@ class PretrainedConfig(object):
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
@@ -451,12 +454,12 @@ class PretrainedConfig(object):
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file: str) -> "PretrainedConfig":
|
||||
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
|
||||
"""
|
||||
Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.
|
||||
|
||||
Args:
|
||||
json_file (:obj:`str`):
|
||||
json_file (:obj:`str` or :obj:`os.PathLike`):
|
||||
Path to the JSON file containing the parameters.
|
||||
|
||||
Returns:
|
||||
@@ -467,7 +470,7 @@ class PretrainedConfig(object):
|
||||
return cls(**config_dict)
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: str):
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
@@ -537,12 +540,12 @@ class PretrainedConfig(object):
|
||||
config_dict = self.to_dict()
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: str, use_diff: bool = True):
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (:obj:`str`):
|
||||
json_file_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If set to ``True``, only the difference between the config instance and the default
|
||||
|
||||
@@ -697,13 +697,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
self.base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (:obj:`str`):
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -741,7 +741,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
logger.info("Model weights saved in {}".format(output_model_file))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||
|
||||
@@ -756,7 +756,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (:obj:`str`, `optional`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
@@ -772,11 +772,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
arguments ``config`` and ``state_dict``).
|
||||
model_args (sequence of positional arguments, `optional`):
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
config (:obj:`Union[PretrainedConfig, str]`, `optional`):
|
||||
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
|
||||
Can be either:
|
||||
|
||||
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
|
||||
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
|
||||
- a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
@@ -794,7 +794,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
weights. In this case though, you should check if using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained` and
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -881,6 +881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
|
||||
# Load from a TF 1.0 checkpoint in priority if from_tf
|
||||
|
||||
@@ -274,7 +274,7 @@ class AutoConfig:
|
||||
List options
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model configuration hosted inside a model repo on
|
||||
@@ -285,7 +285,7 @@ class AutoConfig:
|
||||
:meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a saved configuration JSON `file`, e.g.,
|
||||
``./my_model_directory/configuration.json``.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -346,7 +346,7 @@ class AutoConfig:
|
||||
else:
|
||||
# Fallback: use pattern matching on the string.
|
||||
for pattern, config_class in CONFIG_MAPPING.items():
|
||||
if pattern in pretrained_model_name_or_path:
|
||||
if pattern in str(pretrained_model_name_or_path):
|
||||
return config_class.from_dict(config_dict, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
|
||||
@@ -502,7 +502,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
||||
deactivated). To train the model, you should first set it back in training mode with ``model.train()``
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
@@ -533,7 +533,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
||||
weights. In this case though, you should check if using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained` and
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
|
||||
@@ -267,7 +267,7 @@ class AutoTokenizer:
|
||||
List options
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
|
||||
@@ -283,7 +283,7 @@ class AutoTokenizer:
|
||||
Will be passed along to the Tokenizer ``__init__()`` method.
|
||||
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||
The configuration object used to determine the tokenizer class to instantiate.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
|
||||
@@ -1608,13 +1608,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from
|
||||
a predefined tokenizer.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
|
||||
@@ -1626,7 +1626,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
- (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
|
||||
file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
|
||||
``./my_model_directory/vocab.txt``.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -1683,6 +1683,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
s3_models = list(cls.max_model_input_sizes.keys())
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
vocab_files = {}
|
||||
init_configuration = {}
|
||||
if pretrained_model_name_or_path in s3_models:
|
||||
@@ -1904,7 +1905,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
return tokenizer
|
||||
|
||||
def save_pretrained(
|
||||
self, save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
legacy_format: bool = True,
|
||||
filename_prefix: Optional[str] = None,
|
||||
) -> Tuple[str]:
|
||||
"""
|
||||
Save the full tokenizer state.
|
||||
@@ -1924,7 +1928,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
modifying :obj:`tokenizer.do_lower_case` after creation).
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str`): The path to a directory where the tokenizer will be saved.
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved.
|
||||
legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a
|
||||
separate added_tokens files or in the unified JSON file format for the `tokenizers` library. It's only
|
||||
@@ -1988,7 +1992,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
|
||||
def _save_pretrained(
|
||||
self,
|
||||
save_directory: str,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
file_names: Tuple[str],
|
||||
legacy_format: bool = True,
|
||||
filename_prefix: Optional[str] = None,
|
||||
|
||||
@@ -498,7 +498,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
|
||||
def _save_pretrained(
|
||||
self,
|
||||
save_directory: str,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
file_names: Tuple[str],
|
||||
legacy_format: bool = True,
|
||||
filename_prefix: Optional[str] = None,
|
||||
|
||||
Reference in New Issue
Block a user