Update tiny model summary file for recent models (#22637)
* Update tiny model summary file for recent models --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -991,6 +991,12 @@ def get_config_overrides(config_class, processors):
|
||||
# We use `len(tokenizer)` instead of `tokenizer.vocab_size` to avoid potential issues for tokenizers with non-empty
|
||||
# `added_tokens_encoder`. One example is the `DebertaV2Tokenizer` where the mask token is the extra token.
|
||||
vocab_size = len(tokenizer)
|
||||
|
||||
# The original checkpoint has length `35998`, but ids `30400` and `30514` are missing from it and ids `35998` and
|
||||
# `35999` are used instead, hence the 2 extra entries added below.
|
||||
if config_class.__name__ == "GPTSanJapaneseConfig":
|
||||
vocab_size += 2
|
||||
|
||||
config_overrides["vocab_size"] = vocab_size
|
||||
|
||||
# Used to create a new model tester with `tokenizer.vocab_size` in order to get the (updated) special token ids.
|
||||
@@ -1329,6 +1335,33 @@ def build_simple_report(results):
|
||||
return text, failed_text
|
||||
|
||||
|
||||
def update_tiny_model_summary_file(report_path):
    """Merge a freshly generated tiny model summary into the existing summary.

    Reads `tiny_model_summary.json` from `report_path` and
    `tests/utils/tiny_model_summary.json` (relative to the current working directory),
    merges the former into the latter, and writes the result to
    `updated_tiny_model_summary.json` inside `report_path`. The file under `tests/utils`
    is only read, never written.

    Args:
        report_path: Directory that contains the new `tiny_model_summary.json` and that
            receives `updated_tiny_model_summary.json`.

    Raises:
        FileNotFoundError: If either input summary file does not exist.
    """
    list_attrs = ("tokenizer_classes", "processor_classes", "model_classes")

    # JSON files are read/written as UTF-8 explicitly: the output uses `ensure_ascii=False`, so relying on the
    # platform default encoding could fail or garble non-ASCII content.
    with open(os.path.join(report_path, "tiny_model_summary.json"), encoding="utf-8") as fp:
        new_data = json.load(fp)
    with open("tests/utils/tiny_model_summary.json", encoding="utf-8") as fp:
        data = json.load(fp)

    for key, value in new_data.items():
        if key not in data:
            data[key] = value
            continue
        for attr in list_attrs:
            extra = value.get(attr)
            if extra:
                # we might get duplication here. We will remove them below when creating `updated_data`.
                data[key].setdefault(attr, []).extend(extra)
        # Only overwrite the recorded commit sha when the new summary actually provides one.
        new_sha = value.get("sha")
        if new_sha is not None:
            data[key]["sha"] = new_sha

    # Sort the entries by name; deduplicate and sort every attribute except the scalar `sha`.
    updated_data = {
        key: {attr: value if attr == "sha" else sorted(set(value)) for attr, value in data[key].items()}
        for key in sorted(data)
    }

    with open(os.path.join(report_path, "updated_tiny_model_summary.json"), "w", encoding="utf-8") as fp:
        json.dump(updated_data, fp, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_tiny_models(
|
||||
output_path,
|
||||
all,
|
||||
@@ -1444,6 +1477,8 @@ def create_tiny_models(
|
||||
with open(os.path.join(report_path, "simple_failed_report.txt"), "w") as fp:
|
||||
fp.write(failed_report)
|
||||
|
||||
update_tiny_model_summary_file(report_path=os.path.join(output_path, "reports"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This has to be `spawn` to avoid hanging forever!
|
||||
|
||||
@@ -171,33 +171,6 @@ def get_tiny_model_summary_from_hub(output_path):
|
||||
json.dump(summary, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def update_tiny_model_summary_file(report_path):
    """Combine the new tiny model summary in `report_path` with the existing summary file.

    The new summary is read from `tiny_model_summary.json` inside `report_path`; the existing
    one from `tests/utils/tiny_model_summary.json` (relative to the current working
    directory). The merged result is written to `updated_tiny_model_summary.json` inside
    `report_path`; the file under `tests/utils` is left untouched.

    Args:
        report_path: Directory holding the new summary and receiving the merged output.

    Raises:
        FileNotFoundError: If either input summary file does not exist.
    """
    new_summary_file = os.path.join(report_path, "tiny_model_summary.json")
    output_file = os.path.join(report_path, "updated_tiny_model_summary.json")

    # Use UTF-8 explicitly: the output is dumped with `ensure_ascii=False`, so the platform default encoding
    # must not decide how non-ASCII characters are stored.
    with open(new_summary_file, encoding="utf-8") as fp:
        new_data = json.load(fp)
    with open("tests/utils/tiny_model_summary.json", encoding="utf-8") as fp:
        data = json.load(fp)

    for key, value in new_data.items():
        if key not in data:
            data[key] = value
            continue
        entry = data[key]
        for attr in ("tokenizer_classes", "processor_classes", "model_classes"):
            additions = value.get(attr)
            if additions:
                # we might get duplication here. We will remove them below when creating `updated_data`.
                entry.setdefault(attr, []).extend(additions)
        # A new entry may come without a `sha`; in that case keep whatever is already recorded instead of
        # failing with a `KeyError`.
        new_sha = value.get("sha")
        if new_sha is not None:
            entry["sha"] = new_sha

    updated_data = {}
    for key in sorted(data):
        # deduplication and sort (the `sha` is a single value and is kept as it is)
        updated_data[key] = {
            attr: (value if attr == "sha" else sorted(set(value))) for attr, value in data[key].items()
        }

    with open(output_file, "w", encoding="utf-8") as fp:
        json.dump(updated_data, fp, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
|
||||
@@ -225,5 +198,3 @@ if __name__ == "__main__":
|
||||
token=os.environ.get("TOKEN", None),
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
update_tiny_model_summary_file(report_path=os.path.join(output_path, "reports"))
|
||||
|
||||
Reference in New Issue
Block a user