Add push_to_hub method to processors (#15668)
* Add push_to_hub method to processors * Fix test * The other one too!
This commit is contained in:
@@ -21,9 +21,13 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
|
from .file_utils import PushToHubMixin, copy_func
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
|
# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(
|
||||||
"transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
|
"transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
|
||||||
@@ -37,7 +41,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ProcessorMixin:
|
class ProcessorMixin(PushToHubMixin):
|
||||||
"""
|
"""
|
||||||
This is a mixin used to provide saving/loading functionality for all processor classes.
|
This is a mixin used to provide saving/loading functionality for all processor classes.
|
||||||
"""
|
"""
|
||||||
@@ -88,7 +92,7 @@ class ProcessorMixin:
|
|||||||
attributes_repr = "\n".join(attributes_repr)
|
attributes_repr = "\n".join(attributes_repr)
|
||||||
return f"{self.__class__.__name__}:\n{attributes_repr}"
|
return f"{self.__class__.__name__}:\n{attributes_repr}"
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
|
Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
|
||||||
can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
|
can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
|
||||||
@@ -105,7 +109,24 @@ class ProcessorMixin:
|
|||||||
save_directory (`str` or `os.PathLike`):
|
save_directory (`str` or `os.PathLike`):
|
||||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||||
be created if it does not exist).
|
be created if it does not exist).
|
||||||
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to push your processor to the Hugging Face model hub after saving it.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||||
|
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||||
|
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
kwargs:
|
||||||
|
Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
|
if push_to_hub:
|
||||||
|
commit_message = kwargs.pop("commit_message", None)
|
||||||
|
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||||
|
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||||
# loaded from the Hub.
|
# loaded from the Hub.
|
||||||
@@ -129,6 +150,10 @@ class ProcessorMixin:
|
|||||||
if isinstance(attribute, PreTrainedTokenizerBase):
|
if isinstance(attribute, PreTrainedTokenizerBase):
|
||||||
del attribute.init_kwargs["auto_map"]
|
del attribute.init_kwargs["auto_map"]
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
|
logger.info(f"Processor pushed to the hub in this commit: {url}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
@@ -205,3 +230,9 @@ class ProcessorMixin:
|
|||||||
|
|
||||||
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
||||||
|
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
|
||||||
|
object="processor", object_class="AutoProcessor", object_files="processor files"
|
||||||
|
)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
|
|||||||
)
|
)
|
||||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
|
||||||
|
|
||||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||||
|
|
||||||
|
|
||||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||||
@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, name="test-processor")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
delete_repo(token=cls._token, name="test-dynamic-processor")
|
delete_repo(token=cls._token, name="test-dynamic-processor")
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
processor.save_pretrained(
|
||||||
|
os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
|
||||||
|
for k, v in processor.feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
|
||||||
|
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
processor.save_pretrained(
|
||||||
|
os.path.join(tmp_dir, "test-processor-org"),
|
||||||
|
push_to_hub=True,
|
||||||
|
use_auth_token=self._token,
|
||||||
|
organization="valid_org",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
|
||||||
|
for k, v in processor.feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
|
||||||
|
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
|
||||||
|
|
||||||
def test_push_to_hub_dynamic_processor(self):
|
def test_push_to_hub_dynamic_processor(self):
|
||||||
CustomFeatureExtractor.register_for_auto_class()
|
CustomFeatureExtractor.register_for_auto_class()
|
||||||
CustomTokenizer.register_for_auto_class()
|
CustomTokenizer.register_for_auto_class()
|
||||||
CustomProcessor.register_for_auto_class()
|
CustomProcessor.register_for_auto_class()
|
||||||
|
|
||||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||||
|
|||||||
Reference in New Issue
Block a user