Fix tapas issue (#12063)
* Fix scatter function to be compatible with torch-scatter 2.7.0 * Allow test again
This commit is contained in:
@@ -1697,9 +1697,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
|
|||||||
|
|
||||||
segment_means = scatter(
|
segment_means = scatter(
|
||||||
src=flat_values,
|
src=flat_values,
|
||||||
index=flat_index.indices.type(torch.long),
|
index=flat_index.indices.long(),
|
||||||
dim=0,
|
dim=0,
|
||||||
dim_size=flat_index.num_segments,
|
dim_size=int(flat_index.num_segments),
|
||||||
reduce=segment_reduce_fn,
|
reduce=segment_reduce_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1044,7 +1044,6 @@ class TapasUtilitiesTest(unittest.TestCase):
|
|||||||
# We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual
|
# We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual
|
||||||
np.testing.assert_array_equal(maximum.numpy(), [2, 3])
|
np.testing.assert_array_equal(maximum.numpy(), [2, 3])
|
||||||
|
|
||||||
@unittest.skip("Fix me I'm failing on CI")
|
|
||||||
def test_reduce_sum_vectorized(self):
|
def test_reduce_sum_vectorized(self):
|
||||||
values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
|
values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
|
||||||
index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0)
|
index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user