[DETR and friends] Remove is_timm_available (#21814)
* First draft * Fix to_dict * Improve conversion script * Update config * Remove timm dependency * Fix dummies * Fix typo, add integration test * Upload 101 model as well * Remove timm dummies * Fix style --------- Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -20,7 +20,7 @@ import math
|
||||
import unittest
|
||||
|
||||
from transformers import DetrConfig, is_timm_available, is_vision_available
|
||||
from transformers.testing_utils import require_timm, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -510,7 +510,7 @@ def prepare_img():
|
||||
@require_timm
|
||||
@require_vision
|
||||
@slow
|
||||
class DetrModelIntegrationTests(unittest.TestCase):
|
||||
class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50") if is_vision_available() else None
|
||||
@@ -626,3 +626,33 @@ class DetrModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4))
|
||||
self.assertTrue(len(results["segments_info"]), expected_number_of_segments)
|
||||
self.assertDictEqual(results["segments_info"][0], expected_first_segment)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
@slow
|
||||
class DetrModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return (
|
||||
DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
|
||||
if is_vision_available()
|
||||
else None
|
||||
)
|
||||
|
||||
def test_inference_no_head(self):
|
||||
model = DetrModel.from_pretrained("facebook/detr-resnet-50", revision="no_timm").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
|
||||
expected_shape = torch.Size((1, 100, 256))
|
||||
assert outputs.last_hidden_state.shape == expected_shape
|
||||
expected_slice = torch.tensor(
|
||||
[[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user