Add common test for torch.export and fix some vision models (#35124)
* Add is_torch_greater_or_equal test decorator * Add common test for torch.export * Fix bit * Fix focalnet * Fix imagegpt * Fix seggpt * Fix swin2sr * Enable torch.export test for vision models * Enable test for video models * Remove json * Enable for hiera * Enable for ijepa * Fix detr * Fic conditional_detr * Fix maskformer * Enable test maskformer * Fix test for deformable detr * Fix custom kernels for export in rt-detr and deformable-detr * Enable test for all DPT * Remove custom test for deformable detr * Simplify test to use only kwargs for export * Add comment * Move compile_compatible_method_lru_cache to utils * Fix beit export * Fix deformable detr * Fix copies data2vec<->beit * Fix typos, update test to work with dict * Add seed to the test * Enable test for vit_mae * Fix beit tests * [run-slow] beit, bit, conditional_detr, data2vec, deformable_detr, detr, focalnet, imagegpt, maskformer, rt_detr, seggpt, swin2sr * Add vitpose test * Add textnet test * Add dinov2 with registers * Update tests/test_modeling_common.py * Switch to torch.testing.assert_close * Fix masformer * Remove save-load from test * Add dab_detr * Add depth_pro * Fix and test RT-DETRv2 * Fix dab_detr
This commit is contained in:
committed by
GitHub
parent
1779f5180e
commit
f42d46ccb4
@@ -18,7 +18,7 @@ import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import VitPoseBackboneConfig
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
@@ -27,6 +27,8 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import VitPoseBackbone
|
||||
|
||||
|
||||
@@ -129,6 +131,7 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VitPoseBackboneModelTester(self)
|
||||
@@ -187,6 +190,17 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_torch_export(self):
|
||||
# Dense architecture
|
||||
super().test_torch_export()
|
||||
|
||||
# MOE architecture
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_experts = 2
|
||||
config.part_features = config.hidden_size // config.num_experts
|
||||
inputs_dict["dataset_index"] = torch.tensor([0] * self.model_tester.batch_size, device=torch_device)
|
||||
super().test_torch_export(config=config, inputs_dict=inputs_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
class VitPoseBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
||||
Reference in New Issue
Block a user