Register ModelOutput as supported torch pytree nodes (#26618)

* Register ModelOutput as supported torch pytree nodes

* Test ModelOutput as supported torch pytree nodes

* Update type hints for pytree unflatten functions
This commit is contained in:
Xuehai Pan
2023-10-24 17:02:40 +08:00
committed by GitHub
parent ede051f1b8
commit cc7803c0a6
2 changed files with 30 additions and 13 deletions

View File

@@ -22,7 +22,7 @@ from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import Enum from enum import Enum
from typing import Any, ContextManager, List, Tuple from typing import Any, ContextManager, Iterable, List, Tuple
import numpy as np import numpy as np
@@ -306,12 +306,10 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses. `static_graph=True` with modules that output `ModelOutput` subclasses.
""" """
if is_torch_available(): if is_torch_available():
import torch.utils._pytree _torch_pytree._register_pytree_node(
torch.utils._pytree._register_pytree_node(
cls, cls,
torch.utils._pytree._dict_flatten, _model_output_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), _model_output_unflatten,
) )
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -430,6 +428,23 @@ class ModelOutput(OrderedDict):
return tuple(self[k] for k in self.keys()) return tuple(self[k] for k in self.keys())
if is_torch_available():
    import torch.utils._pytree as _torch_pytree

    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        """Split a `ModelOutput` into its values and the context needed to rebuild it.

        Args:
            output: The `ModelOutput` (or subclass) instance to flatten.

        Returns:
            A pair `(values, context)` where `values` is the list of the output's values and
            `context` is the tuple `(type(output), list_of_keys)`.
        """
        leaves = list(output.values())
        # The concrete class travels with the key names so that unflattening can
        # call the same constructor again.
        context = (type(output), list(output.keys()))
        return leaves, context

    def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
        """Rebuild an output object from flattened values and the matching context.

        Args:
            values: The values to place back into the output, in the same order as the keys
                stored in `context`.
            context: The `(output_type, keys)` tuple produced by `_model_output_flatten`.

        Returns:
            An instance of `output_type` constructed with each key passed as a keyword argument.
        """
        output_type, keys = context
        # Pair every key with its value and hand them to the constructor as keyword arguments.
        fields = dict(zip(keys, values))
        return output_type(**fields)

    # Register the base `ModelOutput` class itself with the flatten/unflatten pair above.
    _torch_pytree._register_pytree_node(ModelOutput, _model_output_flatten, _model_output_unflatten)
class ExplicitEnum(str, Enum): class ExplicitEnum(str, Enum):
""" """
Enum with more explicit error message for missing values. Enum with more explicit error message for missing values.

View File

@@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase):
def test_torch_pytree(self): def test_torch_pytree(self):
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves) # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
# this is important for DistributedDataParallel gradient synchronization with static_graph=True # this is important for DistributedDataParallel gradient synchronization with static_graph=True
import torch import torch.utils._pytree as pytree
import torch.utils._pytree
x = ModelOutput({"a": 1.0, "c": 2.0})
self.assertFalse(pytree._is_leaf(x))
x = ModelOutputTest(a=1.0, c=2.0) x = ModelOutputTest(a=1.0, c=2.0)
self.assertFalse(torch.utils._pytree._is_leaf(x)) self.assertFalse(pytree._is_leaf(x))
expected_flat_outs = [1.0, 2.0] expected_flat_outs = [1.0, 2.0]
expected_tree_spec = torch.utils._pytree.TreeSpec( expected_tree_spec = pytree.TreeSpec(
ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()] ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
) )
actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x) actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs) self.assertEqual(expected_flat_outs, actual_flat_outs)
self.assertEqual(expected_tree_spec, actual_tree_spec) self.assertEqual(expected_tree_spec, actual_tree_spec)
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x) self.assertEqual(x, unflattened_x)