Add push_to_hub method to processors (#15668)

* Add push_to_hub method to processors

* Fix test

* The other one too!
This commit is contained in:
Sylvain Gugger
2022-02-15 21:14:04 -05:00
committed by GitHub
parent bee361c6f1
commit 2d02f7b29b
2 changed files with 73 additions and 4 deletions

View File

@@ -21,9 +21,13 @@ import os
from pathlib import Path from pathlib import Path
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .file_utils import PushToHubMixin, copy_func
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .utils import logging
logger = logging.get_logger(__name__)
# Dynamically import the Transformers module to grab the attribute classes of the processor form their names. # Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
"transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent] "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
@@ -37,7 +41,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
} }
class ProcessorMixin: class ProcessorMixin(PushToHubMixin):
""" """
This is a mixin used to provide saving/loading functionality for all processor classes. This is a mixin used to provide saving/loading functionality for all processor classes.
""" """
@@ -88,7 +92,7 @@ class ProcessorMixin:
attributes_repr = "\n".join(attributes_repr) attributes_repr = "\n".join(attributes_repr)
return f"{self.__class__.__name__}:\n{attributes_repr}" return f"{self.__class__.__name__}:\n{attributes_repr}"
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
""" """
Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
can be reloaded using the [`~ProcessorMixin.from_pretrained`] method. can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
@@ -105,7 +109,24 @@ class ProcessorMixin:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
be created if it does not exist). be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your processor to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
kwargs:
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
""" """
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub. # loaded from the Hub.
@@ -129,6 +150,10 @@ class ProcessorMixin:
if isinstance(attribute, PreTrainedTokenizerBase): if isinstance(attribute, PreTrainedTokenizerBase):
del attribute.init_kwargs["auto_map"] del attribute.init_kwargs["auto_map"]
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Processor pushed to the hub in this commit: {url}")
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" r"""
@@ -205,3 +230,9 @@ class ProcessorMixin:
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args return args
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
object="processor", object_class="AutoProcessor", object_files="processor files"
)

View File

@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
) )
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json") SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
class AutoFeatureExtractorTest(unittest.TestCase): class AutoFeatureExtractorTest(unittest.TestCase):
@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try:
delete_repo(token=cls._token, name="test-processor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, name="test-dynamic-processor") delete_repo(token=cls._token, name="test-dynamic-processor")
except HTTPError: except HTTPError:
pass pass
def test_push_to_hub(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained(
os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token
)
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
def test_push_to_hub_in_organization(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained(
os.path.join(tmp_dir, "test-processor-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
)
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
def test_push_to_hub_dynamic_processor(self): def test_push_to_hub_dynamic_processor(self):
CustomFeatureExtractor.register_for_auto_class() CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class() CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class() CustomProcessor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") vocab_file = os.path.join(tmp_dir, "vocab.txt")