Add automatic-mask-generation pipeline for Segment Anything Model (SAM) (#22840)
* cleanup * updates * more refactoring * make style * update inits * support other inputs in base * update based on review Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com> * Update tests/pipelines/test_pipelines_automatic_mask_generation.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * update * fixup * TODO x and y to refactor, _h _w refactored here * update docstring * more nits * style on these * more doc fix * rename variables * update * updates * style * update * fix `_mask_to_rle_pytorch` * styling * fix ask to rle, wrong outputs * add device arg * update * more updates, fix tets * udpate * update docstrings * styling * fixup * add notebook on the docs * update orginal sizes * fix docstring * updat condition on point_per-batch * updates tests * fix CI test * extend is required, append does not work! * fixup * fix CI tests * whit pixels left * address doc comments * fix doc * slow pipeline tests * update auto init * add revision * make fixup * update p!ipoeline tag when calling tests * alphabeitcal order in inits * fix copies * last style nits * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * reformat docstring * more reformat * address most of the comments * Update src/transformers/pipelines/mask_generation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final refactor * Update src/transformers/models/sam/image_processing_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fixup and fix slow tests * revert --------- Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-h")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-h")
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user