🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models (#33226)

* Fix `torch.jit.tracing` for `interpolate_pos_encoding` in all vision models

* Apply formatting

* Add missing `self.config = config`

* Fix copies

* Fix hiera interpolation unit test

* Formatting

* Update `_import_structure`

* make style

* Fix docstring

* Use `# Copied from` instead of utils

* DeiT variable renaming (`class_and_dist_pos_embed`)

* Fix Hiera `interpolate_pos_encoding`
This commit is contained in:
Joshua Lochner
2024-09-05 16:17:34 +02:00
committed by GitHub
parent 03164ba14e
commit c6d2848a23
26 changed files with 559 additions and 370 deletions

View File

@@ -578,7 +578,7 @@ class HieraModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[1.8522, 0.1532, 0.3849], [2.7352, -0.1941, 0.1848], [1.5859, -0.0773, 0.0168]]
[[1.7853, 0.0690, 0.3177], [2.6853, -0.2334, 0.0889], [1.5445, -0.1515, -0.0300]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))