From 243e186efbf7fb93328dd6b34927a4e8c8f24395 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Mon, 29 Jan 2024 01:41:20 -0800 Subject: [PATCH] Add serialization logic to pytree types (#27871) * Add serialized type name to pytrees * Modify context * add serde test --- src/transformers/pytorch_utils.py | 1 + src/transformers/utils/generic.py | 53 ++++++++++++++++++++----------- tests/utils/test_model_output.py | 41 +++++++++++++++++++++--- 3 files changed, 73 insertions(+), 22 deletions(-) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index eefd6707d4..993da84d33 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) +is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 635bf0f597..d73698d8c9 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -22,11 +22,13 @@ from collections.abc import MutableMapping from contextlib import ExitStack, contextmanager from dataclasses import fields, is_dataclass from enum import Enum +from functools import partial from typing import Any, ContextManager, Iterable, List, Tuple import numpy as np +from packaging import version -from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy +from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy if is_flax_available(): @@ -306,11 +308,19 @@ class ModelOutput(OrderedDict): `static_graph=True` with modules that output `ModelOutput` subclasses. """ if is_torch_available(): - torch_pytree_register_pytree_node( - cls, - _model_output_flatten, - _model_output_unflatten, - ) + if version.parse(get_torch_version()) >= version.parse("2.2"): + _torch_pytree.register_pytree_node( + cls, + _model_output_flatten, + partial(_model_output_unflatten, output_type=cls), + serialized_type_name=f"{cls.__module__}.{cls.__name__}", + ) + else: + _torch_pytree._register_pytree_node( + cls, + _model_output_flatten, + partial(_model_output_unflatten, output_type=cls), + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -432,21 +442,28 @@ if is_torch_available(): import torch.utils._pytree as _torch_pytree def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]: - return list(output.values()), (type(output), list(output.keys())) + return list(output.values()), list(output.keys()) - def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput: - output_type, keys = context - return output_type(**dict(zip(keys, values))) + def _model_output_unflatten( + values: Iterable[Any], + context: "_torch_pytree.Context", + output_type=None, + ) -> ModelOutput: + return output_type(**dict(zip(context, values))) - if hasattr(_torch_pytree, "register_pytree_node"): - torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node + if version.parse(get_torch_version()) >= version.parse("2.2"): + _torch_pytree.register_pytree_node( + ModelOutput, + _model_output_flatten, + partial(_model_output_unflatten, output_type=ModelOutput), + serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}", + ) else: - torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node - torch_pytree_register_pytree_node( - ModelOutput, - _model_output_flatten, - _model_output_unflatten, - ) + _torch_pytree._register_pytree_node( + ModelOutput, + _model_output_flatten, + partial(_model_output_unflatten, output_type=ModelOutput), + ) class ExplicitEnum(str, Enum): diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py index cabdc5fc2d..33013f222e 100644 --- a/tests/utils/test_model_output.py +++ b/tests/utils/test_model_output.py @@ -13,12 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io import unittest from dataclasses import dataclass from typing import Optional +from transformers import AlbertForMaskedLM from transformers.testing_utils import require_torch -from transformers.utils import ModelOutput +from transformers.utils import ModelOutput, is_torch_available + + +if is_torch_available(): + import torch + + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 @dataclass @@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase): self.assertFalse(pytree._is_leaf(x)) expected_flat_outs = [1.0, 2.0] - expected_tree_spec = pytree.TreeSpec( - ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()] - ) + expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()]) actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x) self.assertEqual(expected_flat_outs, actual_flat_outs) @@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase): unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) self.assertEqual(x, unflattened_x) + if is_torch_greater_or_equal_than_2_2: + self.assertEqual( + pytree.treespec_dumps(actual_tree_spec), + '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]', + ) + + @require_torch + def test_export_serialization(self): + if not is_torch_greater_or_equal_than_2_2: + return + + model_cls = AlbertForMaskedLM + model_config = model_cls.config_class() + model = model_cls(model_config) + + input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)} + + ep = torch.export.export(model, (), input_dict) + + buffer = io.BytesIO() + torch.export.save(ep, buffer) + buffer.seek(0) + loaded_ep = torch.export.load(buffer) + + input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)} + assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits) + class ModelOutputTestNoDataclass(ModelOutput): """Invalid test subclass of ModelOutput where @dataclass decorator is not used"""