TF port of the Segment Anything Model (SAM) (#22970)

* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matt
2023-05-19 14:14:13 +01:00
committed by GitHub
parent 8aa8513f71
commit 1c460a5273
14 changed files with 2940 additions and 44 deletions

View File

@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_hidden_states_output(self):
pass
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
@slow
def test_model_from_pretrained(self):
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
@@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
)
def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
@@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze().cpu()
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
EXPECTED_SCORES = torch.tensor(
[
@@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase):
],
]
)
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
def test_inference_mask_generation_one_point_one_bb_zero(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")