Extend typing to path-like objects in PretrainedConfig and PreTrainedModel (#8770)

* update configuration_utils.py typing to allow pathlike objects when sensible

* update modeling_utils.py typing to allow pathlike objects when sensible

* black

* update tokenization_utils_base.py typing to allow pathlike objects when sensible

* update tokenization_utils_fast.py typing to allow pathlike objects when sensible

* update configuration_auto.py typing to allow pathlike objects when sensible

* update configuration_auto.py docstring to allow pathlike objects when sensible

* update tokenization_auto.py docstring to allow pathlike objects when sensible

* black
This commit is contained in:
Giovanni Compagnoni
2020-11-27 16:52:58 +01:00
committed by GitHub
parent a7d46a0609
commit f9a2a9e32b
7 changed files with 42 additions and 34 deletions

View File

@@ -697,13 +697,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
self.base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
Arguments:
save_directory (:obj:`str`):
save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
"""
if os.path.isfile(save_directory):
@@ -741,7 +741,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
logger.info("Model weights saved in {}".format(output_model_file))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
@@ -756,7 +756,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
weights are discarded.
Parameters:
pretrained_model_name_or_path (:obj:`str`, `optional`):
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
@@ -772,11 +772,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
arguments ``config`` and ``state_dict``).
model_args (sequence of positional arguments, `optional`):
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str]`, `optional`):
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
Can be either:
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
- a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
Configuration for the model to use instead of an automatically loaded configuation. Configuration can
be automatically loaded when:
@@ -794,7 +794,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir (:obj:`str`, `optional`):
cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -881,6 +881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Load model
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
# Load from a TF 1.0 checkpoint in priority if from_tf