Automate check for new pipelines and metadata update (#19029)
* Automate check for new pipelines and metadata update * Add Datasets to quality extra
This commit is contained in:
@@ -85,6 +85,12 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
),
|
||||
(
|
||||
"visual-question-answering",
|
||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForVisualQuestionAnswering",
|
||||
),
|
||||
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
|
||||
]
|
||||
|
||||
|
||||
@@ -236,10 +242,35 @@ def update_metadata(token, commit_sha):
|
||||
repo.push_to_hub(commit_message)
|
||||
|
||||
|
||||
def check_pipeline_tags():
|
||||
in_table = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
|
||||
pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
|
||||
missing = []
|
||||
for key in pipeline_tasks:
|
||||
if key not in in_table:
|
||||
model = pipeline_tasks[key]["pt"]
|
||||
if isinstance(model, (list, tuple)):
|
||||
model = model[0]
|
||||
model = model.__name__
|
||||
if model not in in_table.values():
|
||||
missing.append(key)
|
||||
|
||||
if len(missing) > 0:
|
||||
msg = ", ".join(missing)
|
||||
raise ValueError(
|
||||
"The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside "
|
||||
f"`utils/update_metadata.py`: {msg}. Please add them!"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
|
||||
parser.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
|
||||
parser.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.")
|
||||
args = parser.parse_args()
|
||||
|
||||
update_metadata(args.token, args.commit_sha)
|
||||
if args.check_only:
|
||||
check_pipeline_tags()
|
||||
else:
|
||||
update_metadata(args.token, args.commit_sha)
|
||||
|
||||
Reference in New Issue
Block a user