Register ModelOutput as supported torch pytree nodes (#26618)
* Register ModelOutput as supported torch pytree nodes
* Test ModelOutput as supported torch pytree nodes
* Update type hints for pytree unflatten functions
This commit is contained in:
@@ -22,7 +22,7 @@ from collections.abc import MutableMapping
|
|||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, ContextManager, List, Tuple
|
from typing import Any, ContextManager, Iterable, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -306,12 +306,10 @@ class ModelOutput(OrderedDict):
|
|||||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||||
"""
|
"""
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch.utils._pytree
|
_torch_pytree._register_pytree_node(
|
||||||
|
|
||||||
torch.utils._pytree._register_pytree_node(
|
|
||||||
cls,
|
cls,
|
||||||
torch.utils._pytree._dict_flatten,
|
_model_output_flatten,
|
||||||
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
|
_model_output_unflatten,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@@ -430,6 +428,23 @@ class ModelOutput(OrderedDict):
|
|||||||
return tuple(self[k] for k in self.keys())
|
return tuple(self[k] for k in self.keys())
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch.utils._pytree as _torch_pytree
|
||||||
|
|
||||||
|
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
    """Split ``output`` into its values and the context needed to rebuild it.

    Returns:
        A 2-tuple ``(values, context)``. ``values`` is a list of the mapping's
        values in iteration order. ``context`` is ``(type(output), keys)``,
        where ``keys`` is the list of the mapping's keys in the same order,
        so the concrete class is carried along with the key names.
    """
    leaves = list(output.values())
    context = (type(output), list(output.keys()))
    return leaves, context
|
||||||
|
|
||||||
|
def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
    """Rebuild an output object from ``values`` and a ``(type, keys)`` context.

    Args:
        values: The values to place in the rebuilt object, in key order.
        context: A 2-tuple ``(output_type, keys)``.

    Returns:
        ``output_type`` called with one keyword argument per ``(key, value)``
        pair.
    """
    cls, field_names = context
    # zip pairs positionally and stops at the shorter of the two sequences.
    fields_by_name = dict(zip(field_names, values))
    return cls(**fields_by_name)
|
||||||
|
|
||||||
|
_torch_pytree._register_pytree_node(
|
||||||
|
ModelOutput,
|
||||||
|
_model_output_flatten,
|
||||||
|
_model_output_unflatten,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExplicitEnum(str, Enum):
|
class ExplicitEnum(str, Enum):
|
||||||
"""
|
"""
|
||||||
Enum with more explicit error message for missing values.
|
Enum with more explicit error message for missing values.
|
||||||
|
|||||||
@@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase):
|
|||||||
def test_torch_pytree(self):
|
def test_torch_pytree(self):
|
||||||
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
|
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
|
||||||
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
|
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
|
||||||
import torch
|
import torch.utils._pytree as pytree
|
||||||
import torch.utils._pytree
|
|
||||||
|
x = ModelOutput({"a": 1.0, "c": 2.0})
|
||||||
|
self.assertFalse(pytree._is_leaf(x))
|
||||||
|
|
||||||
x = ModelOutputTest(a=1.0, c=2.0)
|
x = ModelOutputTest(a=1.0, c=2.0)
|
||||||
self.assertFalse(torch.utils._pytree._is_leaf(x))
|
self.assertFalse(pytree._is_leaf(x))
|
||||||
|
|
||||||
expected_flat_outs = [1.0, 2.0]
|
expected_flat_outs = [1.0, 2.0]
|
||||||
expected_tree_spec = torch.utils._pytree.TreeSpec(
|
expected_tree_spec = pytree.TreeSpec(
|
||||||
ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()]
|
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
|
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
||||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||||
self.assertEqual(expected_tree_spec, actual_tree_spec)
|
self.assertEqual(expected_tree_spec, actual_tree_spec)
|
||||||
|
|
||||||
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||||
self.assertEqual(x, unflattened_x)
|
self.assertEqual(x, unflattened_x)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user