From cc7803c0a6194ce795ab903979dde9216c82f5bc Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 24 Oct 2023 17:02:40 +0800 Subject: [PATCH] Register ModelOutput as supported torch pytree nodes (#26618) * Register ModelOutput as supported torch pytree nodes * Test ModelOutput as supported torch pytree nodes * Update type hints for pytree unflatten functions --- src/transformers/utils/generic.py | 27 +++++++++++++++++++++------ tests/utils/test_model_output.py | 16 +++++++++------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index dc9ca4b51d..34dac8bea7 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -22,7 +22,7 @@ from collections.abc import MutableMapping from contextlib import ExitStack, contextmanager from dataclasses import fields, is_dataclass from enum import Enum -from typing import Any, ContextManager, List, Tuple +from typing import Any, ContextManager, Iterable, List, Tuple import numpy as np @@ -306,12 +306,10 @@ class ModelOutput(OrderedDict): `static_graph=True` with modules that output `ModelOutput` subclasses. """ if is_torch_available(): - import torch.utils._pytree - - torch.utils._pytree._register_pytree_node( + _torch_pytree._register_pytree_node( cls, - torch.utils._pytree._dict_flatten, - lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + _model_output_flatten, + _model_output_unflatten, ) def __init__(self, *args, **kwargs): @@ -430,6 +428,23 @@ class ModelOutput(OrderedDict): return tuple(self[k] for k in self.keys()) +if is_torch_available(): + import torch.utils._pytree as _torch_pytree + + def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]: + return list(output.values()), (type(output), list(output.keys())) + + def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput: + output_type, keys = context + return output_type(**dict(zip(keys, values))) + + _torch_pytree._register_pytree_node( + ModelOutput, + _model_output_flatten, + _model_output_unflatten, + ) + + class ExplicitEnum(str, Enum): """ Enum with more explicit error message for missing values. diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py index abfc5427cf..cabdc5fc2d 100644 --- a/tests/utils/test_model_output.py +++ b/tests/utils/test_model_output.py @@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase): def test_torch_pytree(self): # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves) # this is important for DistributedDataParallel gradient synchronization with static_graph=True - import torch - import torch.utils._pytree + import torch.utils._pytree as pytree + + x = ModelOutput({"a": 1.0, "c": 2.0}) + self.assertFalse(pytree._is_leaf(x)) x = ModelOutputTest(a=1.0, c=2.0) - self.assertFalse(torch.utils._pytree._is_leaf(x)) + self.assertFalse(pytree._is_leaf(x)) expected_flat_outs = [1.0, 2.0] - expected_tree_spec = torch.utils._pytree.TreeSpec( - ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()] + expected_tree_spec = pytree.TreeSpec( + ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()] ) - actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x) + actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x) self.assertEqual(expected_flat_outs, actual_flat_outs) self.assertEqual(expected_tree_spec, actual_tree_spec) - unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) + unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) self.assertEqual(x, unflattened_x)