Add serialization logic to pytree types (#27871)
* Add serialized type name to pytrees * Modify context * add serde test
This commit is contained in:
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
||||||
|
|
||||||
|
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||||
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
||||||
|
|||||||
@@ -22,11 +22,13 @@ from collections.abc import MutableMapping
|
|||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
from typing import Any, ContextManager, Iterable, List, Tuple
|
from typing import Any, ContextManager, Iterable, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
@@ -306,11 +308,19 @@ class ModelOutput(OrderedDict):
|
|||||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||||
"""
|
"""
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
torch_pytree_register_pytree_node(
|
if version.parse(get_torch_version()) >= version.parse("2.2"):
|
||||||
cls,
|
_torch_pytree.register_pytree_node(
|
||||||
_model_output_flatten,
|
cls,
|
||||||
_model_output_unflatten,
|
_model_output_flatten,
|
||||||
)
|
partial(_model_output_unflatten, output_type=cls),
|
||||||
|
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_torch_pytree._register_pytree_node(
|
||||||
|
cls,
|
||||||
|
_model_output_flatten,
|
||||||
|
partial(_model_output_unflatten, output_type=cls),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -432,21 +442,28 @@ if is_torch_available():
|
|||||||
import torch.utils._pytree as _torch_pytree
|
import torch.utils._pytree as _torch_pytree
|
||||||
|
|
||||||
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
|
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
|
||||||
return list(output.values()), (type(output), list(output.keys()))
|
return list(output.values()), list(output.keys())
|
||||||
|
|
||||||
def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
|
def _model_output_unflatten(
|
||||||
output_type, keys = context
|
values: Iterable[Any],
|
||||||
return output_type(**dict(zip(keys, values)))
|
context: "_torch_pytree.Context",
|
||||||
|
output_type=None,
|
||||||
|
) -> ModelOutput:
|
||||||
|
return output_type(**dict(zip(context, values)))
|
||||||
|
|
||||||
if hasattr(_torch_pytree, "register_pytree_node"):
|
if version.parse(get_torch_version()) >= version.parse("2.2"):
|
||||||
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
|
_torch_pytree.register_pytree_node(
|
||||||
|
ModelOutput,
|
||||||
|
_model_output_flatten,
|
||||||
|
partial(_model_output_unflatten, output_type=ModelOutput),
|
||||||
|
serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
|
_torch_pytree._register_pytree_node(
|
||||||
torch_pytree_register_pytree_node(
|
ModelOutput,
|
||||||
ModelOutput,
|
_model_output_flatten,
|
||||||
_model_output_flatten,
|
partial(_model_output_unflatten, output_type=ModelOutput),
|
||||||
_model_output_unflatten,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ExplicitEnum(str, Enum):
|
class ExplicitEnum(str, Enum):
|
||||||
|
|||||||
@@ -13,12 +13,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import AlbertForMaskedLM
|
||||||
from transformers.testing_utils import require_torch
|
from transformers.testing_utils import require_torch
|
||||||
from transformers.utils import ModelOutput
|
from transformers.utils import ModelOutput, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase):
|
|||||||
self.assertFalse(pytree._is_leaf(x))
|
self.assertFalse(pytree._is_leaf(x))
|
||||||
|
|
||||||
expected_flat_outs = [1.0, 2.0]
|
expected_flat_outs = [1.0, 2.0]
|
||||||
expected_tree_spec = pytree.TreeSpec(
|
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
|
||||||
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
|
|
||||||
)
|
|
||||||
|
|
||||||
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
||||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||||
@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase):
|
|||||||
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||||
self.assertEqual(x, unflattened_x)
|
self.assertEqual(x, unflattened_x)
|
||||||
|
|
||||||
|
if is_torch_greater_or_equal_than_2_2:
|
||||||
|
self.assertEqual(
|
||||||
|
pytree.treespec_dumps(actual_tree_spec),
|
||||||
|
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_export_serialization(self):
|
||||||
|
if not is_torch_greater_or_equal_than_2_2:
|
||||||
|
return
|
||||||
|
|
||||||
|
model_cls = AlbertForMaskedLM
|
||||||
|
model_config = model_cls.config_class()
|
||||||
|
model = model_cls(model_config)
|
||||||
|
|
||||||
|
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||||
|
|
||||||
|
ep = torch.export.export(model, (), input_dict)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.export.save(ep, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
loaded_ep = torch.export.load(buffer)
|
||||||
|
|
||||||
|
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||||
|
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputTestNoDataclass(ModelOutput):
|
class ModelOutputTestNoDataclass(ModelOutput):
|
||||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||||
|
|||||||
Reference in New Issue
Block a user