From d4bd33cc9f11ca48635e54983d75249c78d72e2a Mon Sep 17 00:00:00 2001
From: Matthew Hoffman
Date: Mon, 7 Aug 2023 23:12:11 -0700
Subject: [PATCH] Register ModelOutput subclasses as supported torch.utils._pytree nodes (#25358)

* Register ModelOutput subclasses as supported torch.utils._pytree nodes

Fixes #25357 where DDP with static_graph=True does not sync gradients when calling backward() over tensors contained in ModelOutput subclasses

* Add test for torch pytree ModelOutput serialization and deserialization
---
Review revision notes (text in this section is ignored when the patch is applied):
- `ModelOutput.__init_subclass__` now chains to `super().__init_subclass__()` and prefers the public
  `torch.utils._pytree.register_pytree_node` when the installed torch exposes it, falling back to the private
  `_register_pytree_node` otherwise.
- The first hunk is anchored on the five context lines that are recoverable from the submitted patch text.
- The blob ids on the `index` lines refer to the originally submitted content, not to this revision.

 src/transformers/utils/generic.py | 22 ++++++++++++++++++++++
 tests/utils/test_model_output.py  | 23 +++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index afe1024083..500b6192ab 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -249,5 +249,27 @@ class ModelOutput(OrderedDict):
     """

+    def __init_subclass__(cls) -> None:
+        """Register subclasses as pytree nodes.
+
+        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
+        `static_graph=True` with modules that output `ModelOutput` subclasses.
+        """
+        # Chain to the next `__init_subclass__` in the MRO so hooks defined by other bases still run.
+        super().__init_subclass__()
+        if is_torch_available():
+            import torch.utils._pytree as torch_pytree
+
+            # Prefer the public `register_pytree_node` when this torch build provides it; otherwise fall back to
+            # the private `_register_pytree_node` helper.
+            register_node = getattr(torch_pytree, "register_pytree_node", None)
+            if register_node is None:
+                register_node = torch_pytree._register_pytree_node
+            register_node(
+                cls,
+                torch_pytree._dict_flatten,
+                lambda values, context: cls(**torch_pytree._dict_unflatten(values, context)),
+            )
+
     def __post_init__(self):
         class_fields = fields(self)

diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py
index 20ff5ceba8..b415b6c2ef 100644
--- a/tests/utils/test_model_output.py
+++ b/tests/utils/test_model_output.py
@@ -17,6 +17,7 @@ import unittest
 from dataclasses import dataclass
 from typing import Optional

+from transformers.testing_utils import require_torch
 from transformers.utils import ModelOutput


@@ -120,3 +121,25 @@ class ModelOutputTester(unittest.TestCase):
         x = ModelOutputTest(a=(30, 30))
         self.assertEqual(list(x.keys()), ["a"])
         self.assertEqual(x.a, (30, 30))
+
+    @require_torch
+    def test_torch_pytree(self):
+        # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
+        # this is important for DistributedDataParallel gradient synchronization with static_graph=True
+        import torch
+        import torch.utils._pytree
+
+        x = ModelOutputTest(a=1.0, c=2.0)
+        self.assertFalse(torch.utils._pytree._is_leaf(x))
+
+        expected_flat_outs = [1.0, 2.0]
+        expected_tree_spec = torch.utils._pytree.TreeSpec(
+            ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()]
+        )
+
+        actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
+        self.assertEqual(expected_flat_outs, actual_flat_outs)
+        self.assertEqual(expected_tree_spec, actual_tree_spec)
+
+        unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
+        self.assertEqual(x, unflattened_x)