Extend typing to path-like objects in PretrainedConfig and PreTrainedModel (#8770)

* update configuration_utils.py typing to allow pathlike objects when sensible

* update modeling_utils.py typing to allow pathlike objects when sensible

* black

* update tokenization_utils_base.py typing to allow pathlike objects when sensible

* update tokenization_utils_fast.py typing to allow pathlike objects when sensible

* update configuration_auto.py typing to allow pathlike objects when sensible

* update configuration_auto.py docstring to allow pathlike objects when sensible

* update tokenization_auto.py docstring to allow pathlike objects when sensible

* black
This commit is contained in:
Giovanni Compagnoni
2020-11-27 16:52:58 +01:00
committed by GitHub
parent a7d46a0609
commit f9a2a9e32b
7 changed files with 42 additions and 34 deletions

View File

@@ -1608,13 +1608,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
r"""
Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from
a predefined tokenizer.
Args:
pretrained_model_name_or_path (:obj:`str`):
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
@@ -1626,7 +1626,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
- (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
``./my_model_directory/vocab.txt``.
cache_dir (:obj:`str`, `optional`):
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -1683,6 +1683,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
subfolder = kwargs.pop("subfolder", None)
s3_models = list(cls.max_model_input_sizes.keys())
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
vocab_files = {}
init_configuration = {}
if pretrained_model_name_or_path in s3_models:
@@ -1904,7 +1905,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return tokenizer
def save_pretrained(
self, save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None
self,
save_directory: Union[str, os.PathLike],
legacy_format: bool = True,
filename_prefix: Optional[str] = None,
) -> Tuple[str]:
"""
Save the full tokenizer state.
@@ -1924,7 +1928,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
modifying :obj:`tokenizer.do_lower_case` after creation).
Args:
save_directory (:obj:`str`): The path to a directory where the tokenizer will be saved.
save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved.
legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a
separate added_tokens files or in the unified JSON file format for the `tokenizers` library. It's only
@@ -1988,7 +1992,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def _save_pretrained(
self,
save_directory: str,
save_directory: Union[str, os.PathLike],
file_names: Tuple[str],
legacy_format: bool = True,
filename_prefix: Optional[str] = None,