Extend typing to path-like objects in PretrainedConfig and PreTrainedModel (#8770)
* update configuration_utils.py typing to allow pathlike objects when sensible * update modeling_utils.py typing to allow pathlike objects when sensible * black * update tokenization_utils_base.py typing to allow pathlike objects when sensible * update tokenization_utils_fast.py typing to allow pathlike objects when sensible * update configuration_auto.py typing to allow pathlike objects when sensible * update configuration_auto.py docstring to allow pathlike objects when sensible * update tokenization_auto.py docstring to allow pathlike objects when sensible * black
This commit is contained in:
committed by
GitHub
parent
a7d46a0609
commit
f9a2a9e32b
@@ -19,7 +19,7 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||
from .utils import logging
|
||||
@@ -262,13 +262,13 @@ class PretrainedConfig(object):
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||
|
||||
def save_pretrained(self, save_directory: str):
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
|
||||
:func:`~transformers.PretrainedConfig.from_pretrained` class method.
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str`):
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -281,13 +281,13 @@ class PretrainedConfig(object):
|
||||
logger.info("Configuration saved in {}".format(output_config_file))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig":
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
|
||||
configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
|
||||
@@ -297,7 +297,7 @@ class PretrainedConfig(object):
|
||||
:func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
|
||||
- a path or url to a saved configuration JSON `file`, e.g.,
|
||||
``./my_model_directory/configuration.json``.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -346,13 +346,15 @@ class PretrainedConfig(object):
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
:class:`~transformers.PretrainedConfig` using ``from_dict``.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (:obj:`str`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
|
||||
Returns:
|
||||
@@ -366,6 +368,7 @@ class PretrainedConfig(object):
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
@@ -451,12 +454,12 @@ class PretrainedConfig(object):
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file: str) -> "PretrainedConfig":
|
||||
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
|
||||
"""
|
||||
Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.
|
||||
|
||||
Args:
|
||||
json_file (:obj:`str`):
|
||||
json_file (:obj:`str` or :obj:`os.PathLike`):
|
||||
Path to the JSON file containing the parameters.
|
||||
|
||||
Returns:
|
||||
@@ -467,7 +470,7 @@ class PretrainedConfig(object):
|
||||
return cls(**config_dict)
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: str):
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
@@ -537,12 +540,12 @@ class PretrainedConfig(object):
|
||||
config_dict = self.to_dict()
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: str, use_diff: bool = True):
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (:obj:`str`):
|
||||
json_file_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If set to ``True``, only the difference between the config instance and the default
|
||||
|
||||
Reference in New Issue
Block a user