Add push_to_hub method to processors (#15668)
* Add push_to_hub method to processors * Fix test * The other one too!
This commit is contained in:
@@ -21,9 +21,13 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .file_utils import PushToHubMixin, copy_func
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
|
||||
@@ -37,7 +41,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
class ProcessorMixin:
|
||||
class ProcessorMixin(PushToHubMixin):
|
||||
"""
|
||||
This is a mixin used to provide saving/loading functionality for all processor classes.
|
||||
"""
|
||||
@@ -88,7 +92,7 @@ class ProcessorMixin:
|
||||
attributes_repr = "\n".join(attributes_repr)
|
||||
return f"{self.__class__.__name__}:\n{attributes_repr}"
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
|
||||
can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
|
||||
@@ -105,7 +109,24 @@ class ProcessorMixin:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your processor to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
@@ -129,6 +150,10 @@ class ProcessorMixin:
|
||||
if isinstance(attribute, PreTrainedTokenizerBase):
|
||||
del attribute.init_kwargs["auto_map"]
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Processor pushed to the hub in this commit: {url}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
r"""
|
||||
@@ -205,3 +230,9 @@ class ProcessorMixin:
|
||||
|
||||
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
||||
return args
|
||||
|
||||
|
||||
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
||||
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
|
||||
object="processor", object_class="AutoProcessor", object_files="processor files"
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
|
||||
)
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
|
||||
|
||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-processor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-dynamic-processor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
processor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
|
||||
for k, v in processor.feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
|
||||
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
processor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-processor-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
)
|
||||
|
||||
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
|
||||
for k, v in processor.feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
|
||||
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
|
||||
|
||||
def test_push_to_hub_dynamic_processor(self):
|
||||
CustomFeatureExtractor.register_for_auto_class()
|
||||
CustomTokenizer.register_for_auto_class()
|
||||
CustomProcessor.register_for_auto_class()
|
||||
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
|
||||
Reference in New Issue
Block a user