Register ModelOutput as supported torch pytree nodes (#26618)

* Register ModelOutput as supported torch pytree nodes

* Test ModelOutput as supported torch pytree nodes

* Update type hints for pytree unflatten functions
This commit is contained in:
Xuehai Pan
2023-10-24 17:02:40 +08:00
committed by GitHub
parent ede051f1b8
commit cc7803c0a6
2 changed files with 30 additions and 13 deletions

View File

@@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase):
def test_torch_pytree(self):
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
import torch
import torch.utils._pytree
import torch.utils._pytree as pytree
x = ModelOutput({"a": 1.0, "c": 2.0})
self.assertFalse(pytree._is_leaf(x))
x = ModelOutputTest(a=1.0, c=2.0)
self.assertFalse(torch.utils._pytree._is_leaf(x))
self.assertFalse(pytree._is_leaf(x))
expected_flat_outs = [1.0, 2.0]
expected_tree_spec = torch.utils._pytree.TreeSpec(
ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()]
expected_tree_spec = pytree.TreeSpec(
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
)
actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
self.assertEqual(expected_tree_spec, actual_tree_spec)
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)