Add common test for torch.export and fix some vision models (#35124)
* Add is_torch_greater_or_equal test decorator * Add common test for torch.export * Fix bit * Fix focalnet * Fix imagegpt * Fix seggpt * Fix swin2sr * Enable torch.export test for vision models * Enable test for video models * Remove json * Enable for hiera * Enable for ijepa * Fix detr * Fic conditional_detr * Fix maskformer * Enable test maskformer * Fix test for deformable detr * Fix custom kernels for export in rt-detr and deformable-detr * Enable test for all DPT * Remove custom test for deformable detr * Simplify test to use only kwargs for export * Add comment * Move compile_compatible_method_lru_cache to utils * Fix beit export * Fix deformable detr * Fix copies data2vec<->beit * Fix typos, update test to work with dict * Add seed to the test * Enable test for vit_mae * Fix beit tests * [run-slow] beit, bit, conditional_detr, data2vec, deformable_detr, detr, focalnet, imagegpt, maskformer, rt_detr, seggpt, swin2sr * Add vitpose test * Add textnet test * Add dinov2 with registers * Update tests/test_modeling_common.py * Switch to torch.testing.assert_close * Fix masformer * Remove save-load from test * Add dab_detr * Add depth_pro * Fix and test RT-DETRv2 * Fix dab_detr
This commit is contained in:
committed by
GitHub
parent
1779f5180e
commit
f42d46ccb4
@@ -271,6 +271,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BeitModelTester(self)
|
||||
@@ -292,6 +293,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BEiT can't compile dynamic")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
def test_model_get_set_embeddings(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -764,13 +769,6 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||
# images than what the model supports.
|
||||
self.assertFalse(processor.do_center_crop)
|
||||
with torch.no_grad():
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
model(pixel_values, interpolate_pos_encoding=False)
|
||||
|
||||
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||
# successfully and produce the expected output.
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user