From 70f88eecccb54e344bd8ada1698b4e62ca7d79ff Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 8 Jun 2021 11:22:31 +0200 Subject: [PATCH] Fix tapas issue (#12063) * Fix scatter function to be compatible with torch-scatter 2.7.0 * Allow test again --- src/transformers/models/tapas/modeling_tapas.py | 4 ++-- tests/test_modeling_tapas.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index fb49cb9b2d..11d9c07d9f 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -1697,9 +1697,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name): segment_means = scatter( src=flat_values, - index=flat_index.indices.type(torch.long), + index=flat_index.indices.long(), dim=0, - dim_size=flat_index.num_segments, + dim_size=int(flat_index.num_segments), reduce=segment_reduce_fn, ) diff --git a/tests/test_modeling_tapas.py b/tests/test_modeling_tapas.py index 02c4393b00..40bdba0e70 100644 --- a/tests/test_modeling_tapas.py +++ b/tests/test_modeling_tapas.py @@ -1044,7 +1044,6 @@ class TapasUtilitiesTest(unittest.TestCase): # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(maximum.numpy(), [2, 3]) - @unittest.skip("Fix me I'm failing on CI") def test_reduce_sum_vectorized(self): values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]) index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0)