Add serialization logic to pytree types (#27871)

* Add serialized type name to pytrees

* Modify context

* add serde test
This commit is contained in:
Angela Yi
2024-01-29 01:41:20 -08:00
committed by GitHub
parent f1cc615721
commit 243e186efb
3 changed files with 73 additions and 22 deletions

View File

@@ -13,12 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import unittest
from dataclasses import dataclass
from typing import Optional
from transformers import AlbertForMaskedLM
from transformers.testing_utils import require_torch
from transformers.utils import ModelOutput
from transformers.utils import ModelOutput, is_torch_available
if is_torch_available():
import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
@dataclass
@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase):
self.assertFalse(pytree._is_leaf(x))
expected_flat_outs = [1.0, 2.0]
expected_tree_spec = pytree.TreeSpec(
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
)
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase):
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)
if is_torch_greater_or_equal_than_2_2:
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)
@require_torch
def test_export_serialization(self):
if not is_torch_greater_or_equal_than_2_2:
return
model_cls = AlbertForMaskedLM
model_config = model_cls.config_class()
model = model_cls(model_config)
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
ep = torch.export.export(model, (), input_dict)
buffer = io.BytesIO()
torch.export.save(ep, buffer)
buffer.seek(0)
loaded_ep = torch.export.load(buffer)
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""