From 5e4b69dc12980ce4ee387cb449bfb1169b4f74c3 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 4 Mar 2024 11:51:16 +0100 Subject: [PATCH] Convert SlimSAM checkpoints (#28379) * First commit * Improve conversion script * Convert more checkpoints * Update src/transformers/models/sam/convert_sam_original_to_hf_format.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Rename file * More updates * Update docstring * Update script --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- ...l_to_hf_format.py => convert_sam_to_hf.py} | 138 ++++++++++++------ utils/not_doctested.txt | 2 +- 2 files changed, 92 insertions(+), 48 deletions(-) rename src/transformers/models/sam/{convert_sam_original_to_hf_format.py => convert_sam_to_hf.py} (69%) diff --git a/src/transformers/models/sam/convert_sam_original_to_hf_format.py b/src/transformers/models/sam/convert_sam_to_hf.py similarity index 69% rename from src/transformers/models/sam/convert_sam_original_to_hf_format.py rename to src/transformers/models/sam/convert_sam_to_hf.py index b3cb45b347..be375494f0 100644 --- a/src/transformers/models/sam/convert_sam_original_to_hf_format.py +++ b/src/transformers/models/sam/convert_sam_to_hf.py @@ -14,6 +14,10 @@ # limitations under the License. """ Convert SAM checkpoints from the original repository. + +URL: https://github.com/facebookresearch/segment-anything. + +Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master. """ import argparse import re @@ -33,6 +37,47 @@ from transformers import ( ) +def get_config(model_name): + if "slimsam-50" in model_name: + vision_config = SamVisionConfig( + hidden_size=384, + mlp_dim=1536, + num_hidden_layers=12, + num_attention_heads=12, + global_attn_indexes=[2, 5, 8, 11], + ) + elif "slimsam-77" in model_name: + vision_config = SamVisionConfig( + hidden_size=168, + mlp_dim=696, + num_hidden_layers=12, + num_attention_heads=12, + global_attn_indexes=[2, 5, 8, 11], + ) + elif "sam_vit_b" in model_name: + vision_config = SamVisionConfig() + elif "sam_vit_l" in model_name: + vision_config = SamVisionConfig( + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + global_attn_indexes=[5, 11, 17, 23], + ) + elif "sam_vit_h" in model_name: + vision_config = SamVisionConfig( + hidden_size=1280, + num_hidden_layers=32, + num_attention_heads=16, + global_attn_indexes=[7, 15, 23, 31], + ) + + config = SamConfig( + vision_config=vision_config, + ) + + return config + + KEYS_TO_MODIFY_MAPPING = { "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", @@ -88,63 +133,47 @@ def replace_keys(state_dict): return model_state_dict -def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_hub_id="ybelkada/segment-anything"): - checkpoint_path = hf_hub_download(model_hub_id, f"checkpoints/{model_name}.pth") - - if "sam_vit_b" in model_name: - config = SamConfig() - elif "sam_vit_l" in model_name: - vision_config = SamVisionConfig( - hidden_size=1024, - num_hidden_layers=24, - num_attention_heads=16, - global_attn_indexes=[5, 11, 17, 23], - ) - - config = SamConfig( - vision_config=vision_config, - ) - elif "sam_vit_h" in model_name: - vision_config = SamVisionConfig( - hidden_size=1280, - num_hidden_layers=32, - num_attention_heads=16, - global_attn_indexes=[7, 15, 23, 31], - ) - - config = SamConfig( - vision_config=vision_config, - ) +def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub): + config = get_config(model_name) state_dict = torch.load(checkpoint_path, map_location="cpu") state_dict = replace_keys(state_dict) image_processor = SamImageProcessor() - processor = SamProcessor(image_processor=image_processor) hf_model = SamModel(config) + hf_model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" hf_model.load_state_dict(state_dict) - hf_model = hf_model.to("cuda") + hf_model = hf_model.to(device) img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") - input_points = [[[400, 650]]] + input_points = [[[500, 375]]] input_labels = [[1]] - inputs = processor(images=np.array(raw_image), return_tensors="pt").to("cuda") + inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device) with torch.no_grad(): output = hf_model(**inputs) scores = output.iou_scores.squeeze() - if model_name == "sam_vit_h_4b8939": - assert scores[-1].item() == 0.579890251159668 - + if model_name == "sam_vit_b_01ec64": inputs = processor( images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" - ).to("cuda") + ).to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + elif model_name == "sam_vit_h_4b8939": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) with torch.no_grad(): output = hf_model(**inputs) @@ -154,7 +183,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h input_boxes = ((75, 275, 1725, 850),) - inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to("cuda") + inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device) with torch.no_grad(): output = hf_model(**inputs) @@ -168,7 +197,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h inputs = processor( images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" - ).to("cuda") + ).to(device) with torch.no_grad(): output = hf_model(**inputs) @@ -176,16 +205,31 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h assert scores[-1].item() == 0.9936047792434692 + if pytorch_dump_folder is not None: + processor.save_pretrained(pytorch_dump_folder) + hf_model.save_pretrained(pytorch_dump_folder) + + if push_to_hub: + repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}" + processor.push_to_hub(repo_id) + hf_model.push_to_hub(repo_id) + if __name__ == "__main__": parser = argparse.ArgumentParser() - choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195"] + choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"] parser.add_argument( "--model_name", default="sam_vit_h_4b8939", choices=choices, type=str, - help="Path to hf config.json of model to convert", + help="Name of the original model to convert", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=False, + help="Path to the original checkpoint", ) parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") parser.add_argument( @@ -193,14 +237,14 @@ if __name__ == "__main__": action="store_true", help="Whether to push the model and processor to the hub after converting", ) - parser.add_argument( - "--model_hub_id", - default="ybelkada/segment-anything", - choices=choices, - type=str, - help="Path to hf config.json of model to convert", - ) args = parser.parse_args() - convert_sam_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.model_hub_id) + if "slimsam" in args.model_name: + checkpoint_path = args.checkpoint_path + if checkpoint_path is None: + raise ValueError("You need to provide a checkpoint path for SlimSAM models.") + else: + checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth") + + convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index daf47b1cb1..3e4c78cd9c 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -784,7 +784,7 @@ src/transformers/models/rwkv/configuration_rwkv.py src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py src/transformers/models/rwkv/modeling_rwkv.py src/transformers/models/sam/configuration_sam.py -src/transformers/models/sam/convert_sam_original_to_hf_format.py +src/transformers/models/sam/convert_sam_to_hf.py src/transformers/models/sam/image_processing_sam.py src/transformers/models/sam/modeling_sam.py src/transformers/models/sam/modeling_tf_sam.py