Extend typing to path-like objects in PretrainedConfig and PreTrainedModel (#8770)
* update configuration_utils.py typing to allow pathlike objects when sensible * update modeling_utils.py typing to allow pathlike objects when sensible * black * update tokenization_utils_base.py typing to allow pathlike objects when sensible * update tokenization_utils_fast.py typing to allow pathlike objects when sensible * update configuration_auto.py typing to allow pathlike objects when sensible * update configuration_auto.py docstring to allow pathlike objects when sensible * update tokenization_auto.py docstring to allow pathlike objects when sensible * black
This commit is contained in:
committed by
GitHub
parent
a7d46a0609
commit
f9a2a9e32b
@@ -697,13 +697,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
self.base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (:obj:`str`):
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -741,7 +741,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
logger.info("Model weights saved in {}".format(output_model_file))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||
|
||||
@@ -756,7 +756,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (:obj:`str`, `optional`):
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
@@ -772,11 +772,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
arguments ``config`` and ``state_dict``).
|
||||
model_args (sequence of positional arguments, `optional`):
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
config (:obj:`Union[PretrainedConfig, str]`, `optional`):
|
||||
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
|
||||
Can be either:
|
||||
|
||||
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
|
||||
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
|
||||
- a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can
|
||||
be automatically loaded when:
|
||||
@@ -794,7 +794,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
weights. In this case though, you should check if using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained` and
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -881,6 +881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
|
||||
# Load from a TF 1.0 checkpoint in priority if from_tf
|
||||
|
||||
Reference in New Issue
Block a user