Don't save processor_config.json if a processor has no extra attribute (#28584)

* not save if empty

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-01-19 09:59:14 +00:00
committed by GitHub
parent 772307be76
commit db9a7e9d3d
3 changed files with 32 additions and 18 deletions

View File

@@ -101,6 +101,12 @@ class AutoFeatureExtractorTest(unittest.TestCase):
# save in new folder
processor.save_pretrained(tmpdirname)
if not os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
# create one manually in order to perform this test's objective
config_dict = {"processor_class": "Wav2Vec2Processor"}
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as fp:
json.dump(config_dict, fp)
# drop `processor_class` in tokenizer config
with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE), "r") as f:
config_dict = json.load(f)
@@ -123,13 +129,14 @@ class AutoFeatureExtractorTest(unittest.TestCase):
# save in new folder
processor.save_pretrained(tmpdirname)
# drop `processor_class` in processor
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "r") as f:
config_dict = json.load(f)
config_dict.pop("processor_class")
if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
# drop `processor_class` in processor
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "r") as f:
config_dict = json.load(f)
config_dict.pop("processor_class")
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f:
f.write(json.dumps(config_dict))
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f:
f.write(json.dumps(config_dict))
# drop `processor_class` in tokenizer
with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE), "r") as f:
@@ -153,13 +160,14 @@ class AutoFeatureExtractorTest(unittest.TestCase):
# save in new folder
processor.save_pretrained(tmpdirname)
# drop `processor_class` in processor
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "r") as f:
config_dict = json.load(f)
config_dict.pop("processor_class")
if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
# drop `processor_class` in processor
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "r") as f:
config_dict = json.load(f)
config_dict.pop("processor_class")
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f:
f.write(json.dumps(config_dict))
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "w") as f:
f.write(json.dumps(config_dict))
# drop `processor_class` in feature extractor
with open(os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME), "r") as f:

View File

@@ -75,11 +75,12 @@ class ProcessorTesterMixin:
processor_first = self.get_processor()
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = processor_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
processor_second = self.processor_class.from_pretrained(tmpdirname)
saved_files = processor_first.save_pretrained(tmpdirname)
if len(saved_files) > 0:
check_json_file_has_correct_format(saved_files[0])
processor_second = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor_second.to_dict(), processor_first.to_dict())
self.assertEqual(processor_second.to_dict(), processor_first.to_dict())
class MyProcessor(ProcessorMixin):