Add serialization logic to pytree types (#27871)
* Add serialized type name to pytrees * Modify context * add serde test
This commit is contained in:
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
||||
|
||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
||||
|
||||
@@ -22,11 +22,13 @@ from collections.abc import MutableMapping
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, ContextManager, Iterable, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
||||
from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -306,11 +308,19 @@ class ModelOutput(OrderedDict):
|
||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||
"""
|
||||
if is_torch_available():
|
||||
torch_pytree_register_pytree_node(
|
||||
cls,
|
||||
_model_output_flatten,
|
||||
_model_output_unflatten,
|
||||
)
|
||||
if version.parse(get_torch_version()) >= version.parse("2.2"):
|
||||
_torch_pytree.register_pytree_node(
|
||||
cls,
|
||||
_model_output_flatten,
|
||||
partial(_model_output_unflatten, output_type=cls),
|
||||
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
|
||||
)
|
||||
else:
|
||||
_torch_pytree._register_pytree_node(
|
||||
cls,
|
||||
_model_output_flatten,
|
||||
partial(_model_output_unflatten, output_type=cls),
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -432,21 +442,28 @@ if is_torch_available():
|
||||
import torch.utils._pytree as _torch_pytree
|
||||
|
||||
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
|
||||
return list(output.values()), (type(output), list(output.keys()))
|
||||
return list(output.values()), list(output.keys())
|
||||
|
||||
def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
|
||||
output_type, keys = context
|
||||
return output_type(**dict(zip(keys, values)))
|
||||
def _model_output_unflatten(
|
||||
values: Iterable[Any],
|
||||
context: "_torch_pytree.Context",
|
||||
output_type=None,
|
||||
) -> ModelOutput:
|
||||
return output_type(**dict(zip(context, values)))
|
||||
|
||||
if hasattr(_torch_pytree, "register_pytree_node"):
|
||||
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
|
||||
if version.parse(get_torch_version()) >= version.parse("2.2"):
|
||||
_torch_pytree.register_pytree_node(
|
||||
ModelOutput,
|
||||
_model_output_flatten,
|
||||
partial(_model_output_unflatten, output_type=ModelOutput),
|
||||
serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
|
||||
)
|
||||
else:
|
||||
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
|
||||
torch_pytree_register_pytree_node(
|
||||
ModelOutput,
|
||||
_model_output_flatten,
|
||||
_model_output_unflatten,
|
||||
)
|
||||
_torch_pytree._register_pytree_node(
|
||||
ModelOutput,
|
||||
_model_output_flatten,
|
||||
partial(_model_output_unflatten, output_type=ModelOutput),
|
||||
)
|
||||
|
||||
|
||||
class ExplicitEnum(str, Enum):
|
||||
|
||||
@@ -13,12 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AlbertForMaskedLM
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils import ModelOutput
|
||||
from transformers.utils import ModelOutput, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase):
|
||||
self.assertFalse(pytree._is_leaf(x))
|
||||
|
||||
expected_flat_outs = [1.0, 2.0]
|
||||
expected_tree_spec = pytree.TreeSpec(
|
||||
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
|
||||
)
|
||||
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
|
||||
|
||||
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||
@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase):
|
||||
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
if is_torch_greater_or_equal_than_2_2:
|
||||
self.assertEqual(
|
||||
pytree.treespec_dumps(actual_tree_spec),
|
||||
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_export_serialization(self):
|
||||
if not is_torch_greater_or_equal_than_2_2:
|
||||
return
|
||||
|
||||
model_cls = AlbertForMaskedLM
|
||||
model_config = model_cls.config_class()
|
||||
model = model_cls(model_config)
|
||||
|
||||
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||
|
||||
ep = torch.export.export(model, (), input_dict)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.export.save(ep, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_ep = torch.export.load(buffer)
|
||||
|
||||
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
|
||||
|
||||
|
||||
class ModelOutputTestNoDataclass(ModelOutput):
|
||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||
|
||||
Reference in New Issue
Block a user