Update tiny model summary file for recent models (#22637)

* Update tiny model summary file for recent models

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-04-06 22:52:59 +02:00
committed by GitHub
parent ed67286465
commit c7ec71baf5
6 changed files with 222 additions and 44 deletions

View File

@@ -991,6 +991,12 @@ def get_config_overrides(config_class, processors):
# We use `len(tokenizer)` instead of `tokenizer.vocab_size` to avoid potential issues for tokenizers with non-empty
# `added_tokens_encoder`. One example is the `DebertaV2Tokenizer` where the mask token is the extra token.
vocab_size = len(tokenizer)
# The original checkpoint has length `35998`. However, it does not contain the ids `30400` and `30514`, and uses
# `35998` and `35999` instead, so the vocabulary size must be increased by 2.
if config_class.__name__ == "GPTSanJapaneseConfig":
vocab_size += 2
config_overrides["vocab_size"] = vocab_size
# Used to create a new model tester with `tokenizer.vocab_size` in order to get the (updated) special token ids.
@@ -1329,6 +1335,33 @@ def build_simple_report(results):
return text, failed_text
def update_tiny_model_summary_file(report_path, summary_path="tests/utils/tiny_model_summary.json"):
    """Merge a newly generated tiny model summary into the existing summary and write the result to a new file.

    The new summary is read from `tiny_model_summary.json` inside `report_path` and merged into the data loaded from
    `summary_path`. The merged result (keys sorted, list entries deduplicated and sorted) is written to
    `updated_tiny_model_summary.json` inside `report_path`. The file at `summary_path` is only read, never written.

    Args:
        report_path (`str`):
            Directory containing `tiny_model_summary.json`; also where `updated_tiny_model_summary.json` is written.
        summary_path (`str`, *optional*, defaults to `"tests/utils/tiny_model_summary.json"`):
            Path of the existing summary file that the new data is merged into.
    """
    list_attrs = ("tokenizer_classes", "processor_classes", "model_classes")

    # Explicit UTF-8: the output is dumped with `ensure_ascii=False`, so relying on the locale's default encoding
    # could fail (or produce a file that can't be read back) on platforms where that default is not UTF-8.
    with open(os.path.join(report_path, "tiny_model_summary.json"), encoding="utf-8") as fp:
        new_data = json.load(fp)
    with open(summary_path, encoding="utf-8") as fp:
        data = json.load(fp)

    for key, value in new_data.items():
        if key not in data:
            data[key] = value
            continue
        for attr in list_attrs:
            # We might get duplication here. We will remove them below when creating `updated_data`.
            # A list attribute missing on either side is treated as an empty list instead of raising `KeyError`.
            data[key].setdefault(attr, []).extend(value.get(attr, []))
        # Only overwrite the recorded sha when the new entry actually provides one.
        new_sha = value.get("sha", None)
        if new_sha is not None:
            data[key]["sha"] = new_sha

    updated_data = {}
    for key in sorted(data):
        # Deduplicate and sort every attribute except `sha`, which is a single value rather than a list.
        updated_data[key] = {
            attr: (value if attr == "sha" else sorted(set(value))) for attr, value in data[key].items()
        }

    with open(os.path.join(report_path, "updated_tiny_model_summary.json"), "w", encoding="utf-8") as fp:
        json.dump(updated_data, fp, indent=4, ensure_ascii=False)
def create_tiny_models(
output_path,
all,
@@ -1444,6 +1477,8 @@ def create_tiny_models(
with open(os.path.join(report_path, "simple_failed_report.txt"), "w") as fp:
fp.write(failed_report)
update_tiny_model_summary_file(report_path=os.path.join(output_path, "reports"))
if __name__ == "__main__":
# This has to be `spawn` to avoid hanging forever!

View File

@@ -171,33 +171,6 @@ def get_tiny_model_summary_from_hub(output_path):
json.dump(summary, fp, ensure_ascii=False, indent=4)
def update_tiny_model_summary_file(report_path, summary_path="tests/utils/tiny_model_summary.json"):
    """Merge a newly generated tiny model summary into the existing summary and write the result to a new file.

    The new summary is read from `tiny_model_summary.json` inside `report_path` and merged into the data loaded from
    `summary_path`. The merged result (keys sorted, list entries deduplicated and sorted) is written to
    `updated_tiny_model_summary.json` inside `report_path`. The file at `summary_path` is only read, never written.

    Args:
        report_path (`str`):
            Directory containing `tiny_model_summary.json`; also where `updated_tiny_model_summary.json` is written.
        summary_path (`str`, *optional*, defaults to `"tests/utils/tiny_model_summary.json"`):
            Path of the existing summary file that the new data is merged into.
    """
    list_attrs = ("tokenizer_classes", "processor_classes", "model_classes")

    # Explicit UTF-8: the output is dumped with `ensure_ascii=False`, so relying on the locale's default encoding
    # could fail (or produce a file that can't be read back) on platforms where that default is not UTF-8.
    with open(os.path.join(report_path, "tiny_model_summary.json"), encoding="utf-8") as fp:
        new_data = json.load(fp)
    with open(summary_path, encoding="utf-8") as fp:
        data = json.load(fp)

    for key, value in new_data.items():
        if key not in data:
            data[key] = value
            continue
        for attr in list_attrs:
            # We might get duplication here. We will remove them below when creating `updated_data`.
            # A list attribute missing on either side is treated as an empty list instead of raising `KeyError`.
            data[key].setdefault(attr, []).extend(value.get(attr, []))
        # A new entry may come without a `sha`: use `.get` so the previously recorded value is kept in that case
        # instead of failing with `KeyError`.
        new_sha = value.get("sha", None)
        if new_sha is not None:
            data[key]["sha"] = new_sha

    updated_data = {}
    for key in sorted(data):
        # Deduplicate and sort every attribute except `sha`, which is a single value rather than a list.
        updated_data[key] = {
            attr: (value if attr == "sha" else sorted(set(value))) for attr, value in data[key].items()
        }

    with open(os.path.join(report_path, "updated_tiny_model_summary.json"), "w", encoding="utf-8") as fp:
        json.dump(updated_data, fp, indent=4, ensure_ascii=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
@@ -225,5 +198,3 @@ if __name__ == "__main__":
token=os.environ.get("TOKEN", None),
num_workers=args.num_workers,
)
update_tiny_model_summary_file(report_path=os.path.join(output_path, "reports"))