[AutoProcessor] Correct AutoProcessor and automatically add processor class (#14881)
* [AutoProcessor] Correct AutoProcessor and automatically add processor class
* up
* up
* up
* up
* up
* up
* up
* up
* continue tomorrow
* up
* up
* up
* make processor class private
* fix loop
This commit is contained in:
committed by
GitHub
parent
d7d60df0ec
commit
a1392883ce
@@ -202,6 +202,8 @@ class FeatureExtractionMixin:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Set elements of `kwargs` as attributes."""
|
||||
# Pop "processor_class" as it should be saved as a private attribute
|
||||
self._processor_class = kwargs.pop("processor_class", None)
|
||||
# Additional attributes without default values
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
@@ -210,6 +212,10 @@ class FeatureExtractionMixin:
|
||||
logger.error(f"Can't set {key} with value {value} for {self}")
|
||||
raise err
|
||||
|
||||
def _set_processor_class(self, processor_class: str):
    """
    Store the given processor class name on the private `_processor_class` attribute.

    Args:
        processor_class (`str`): Name of the processor class to record on this object.
    """
    self._processor_class = processor_class
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
@@ -481,6 +487,11 @@ class FeatureExtractionMixin:
|
||||
if isinstance(value, np.ndarray):
|
||||
dictionary[key] = value.tolist()
|
||||
|
||||
# make sure private name "_processor_class" is correctly
|
||||
# saved as "processor_class"
|
||||
if dictionary.get("_processor_class", None) is not None:
|
||||
dictionary["processor_class"] = dictionary.pop("_processor_class")
|
||||
|
||||
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
|
||||
@@ -20,6 +20,7 @@ from collections import OrderedDict
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_list_of_files
|
||||
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
@@ -28,6 +29,7 @@ from .configuration_auto import (
|
||||
model_type_to_module_name,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
from .tokenization_auto import get_tokenizer_config
|
||||
|
||||
|
||||
PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
@@ -151,12 +153,20 @@ class AutoProcessor:
|
||||
# strip to file name
|
||||
model_files = [f.split("/")[-1] for f in model_files]
|
||||
|
||||
# Let's start by checking whether the processor class is saved in a feature extractor
|
||||
if FEATURE_EXTRACTOR_NAME in model_files:
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "processor_class" in config_dict:
|
||||
processor_class = processor_class_from_name(config_dict["processor_class"])
|
||||
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# Next, let's check whether the processor class is saved in a tokenizer
|
||||
if TOKENIZER_CONFIG_FILE in model_files:
|
||||
config_dict = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
||||
if "processor_class" in config_dict:
|
||||
processor_class = processor_class_from_name(config_dict["processor_class"])
|
||||
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# Otherwise, load config, if it can be loaded.
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
@@ -64,8 +64,10 @@ class CLIPProcessor:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -75,8 +75,10 @@ class LayoutLMv2Processor:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -76,8 +76,10 @@ class LayoutXLMProcessor:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -69,8 +69,10 @@ class Speech2TextProcessor:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -76,8 +76,10 @@ class VisionTextDualEncoderProcessor:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -72,8 +72,10 @@ class Wav2Vec2Processor:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -116,7 +116,10 @@ class Wav2Vec2ProcessorWithLM:
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
self.feature_extractor._set_processor_class(self.__class__.__name__)
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
|
||||
self.tokenizer._set_processor_class(self.__class__.__name__)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
self.decoder.save_to_dir(save_directory)
|
||||
|
||||
|
||||
@@ -1444,6 +1444,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
self.init_inputs = ()
|
||||
self.init_kwargs = copy.deepcopy(kwargs)
|
||||
self.name_or_path = kwargs.pop("name_or_path", "")
|
||||
self._processor_class = kwargs.pop("processor_class", None)
|
||||
|
||||
# For backward compatibility we fallback to set model_max_length from max_len if provided
|
||||
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
|
||||
@@ -1505,6 +1506,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
"Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
|
||||
)
|
||||
|
||||
def _set_processor_class(self, processor_class: str):
    """
    Store the given processor class name on the private `_processor_class` attribute.

    Args:
        processor_class (`str`): Name of the processor class to record on this tokenizer.
    """
    self._processor_class = processor_class
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', "
|
||||
@@ -2029,6 +2034,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
tokenizer_config["tokenizer_class"] = tokenizer_class
|
||||
if getattr(self, "_auto_map", None) is not None:
|
||||
tokenizer_config["auto_map"] = self._auto_map
|
||||
if getattr(self, "_processor_class", None) is not None:
|
||||
tokenizer_config["processor_class"] = self._processor_class
|
||||
|
||||
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||
|
||||
@@ -13,13 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from shutil import copyfile
|
||||
|
||||
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
|
||||
from transformers import AutoProcessor, AutoTokenizer, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
|
||||
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||
|
||||
|
||||
SAMPLE_PROCESSOR_CONFIG = os.path.join(
|
||||
@@ -55,3 +57,61 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
processor = AutoProcessor.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
def test_processor_from_feat_extr_processor_class(self):
    """
    Check that `AutoProcessor` can resolve the processor class from the feature extractor file alone.

    A `Wav2Vec2Processor` is saved to a temporary directory, then `processor_class` is removed from the
    tokenizer config so that the feature extractor file is the only saved file still carrying it.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        feature_extractor = Wav2Vec2FeatureExtractor()
        tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")

        processor = Wav2Vec2Processor(feature_extractor, tokenizer)

        # save in new folder
        processor.save_pretrained(tmpdirname)

        tokenizer_config_path = os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE)

        # drop `processor_class` in tokenizer
        # The tokenizer config is written as UTF-8, so read and rewrite it with an explicit
        # encoding instead of relying on the platform default.
        with open(tokenizer_config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
            config_dict.pop("processor_class")

        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config_dict))

        processor = AutoProcessor.from_pretrained(tmpdirname)

    self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
def test_processor_from_tokenizer_processor_class(self):
    """
    Check that `AutoProcessor` can resolve the processor class from the tokenizer config alone.

    A `Wav2Vec2Processor` is saved to a temporary directory, then `processor_class` is removed from the
    feature extractor file so that the tokenizer config is the only saved file still carrying it.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        feature_extractor = Wav2Vec2FeatureExtractor()
        tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")

        processor = Wav2Vec2Processor(feature_extractor, tokenizer)

        # save in new folder
        processor.save_pretrained(tmpdirname)

        feature_extractor_path = os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME)

        # drop `processor_class` in feature extractor
        # Use an explicit encoding so the JSON round-trip does not depend on the platform default.
        with open(feature_extractor_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
            config_dict.pop("processor_class")

        with open(feature_extractor_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config_dict))

        processor = AutoProcessor.from_pretrained(tmpdirname)

    self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
def test_processor_from_local_directory_from_model_config(self):
    """
    Check that `AutoProcessor` loads a `Wav2Vec2Processor` when only the model config names it.

    The temporary directory holds a saved `Wav2Vec2Config` created with
    `processor_class="Wav2Vec2Processor"`, a copy of the sample vocab and an empty (`{}`) feature
    extractor file, so the feature extractor file itself carries no `processor_class` entry.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        Wav2Vec2Config(processor_class="Wav2Vec2Processor").save_pretrained(tmpdirname)

        # copy relevant files
        vocab_target = os.path.join(tmpdirname, "vocab.json")
        copyfile(SAMPLE_VOCAB, vocab_target)

        # create empty sample processor
        feature_extractor_file = os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(feature_extractor_file, "w") as f:
            f.write("{}")

        processor = AutoProcessor.from_pretrained(tmpdirname)

    self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
Reference in New Issue
Block a user