Add serialization logic to pytree types (#27871)
* Add serialized type name to pytrees * Modify context * add serde test
This commit is contained in:
@@ -13,12 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AlbertForMaskedLM
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils import ModelOutput
|
||||
from transformers.utils import ModelOutput, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase):
|
||||
self.assertFalse(pytree._is_leaf(x))
|
||||
|
||||
expected_flat_outs = [1.0, 2.0]
|
||||
expected_tree_spec = pytree.TreeSpec(
|
||||
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
|
||||
)
|
||||
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
|
||||
|
||||
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||
@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase):
|
||||
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
if is_torch_greater_or_equal_than_2_2:
|
||||
self.assertEqual(
|
||||
pytree.treespec_dumps(actual_tree_spec),
|
||||
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_export_serialization(self):
|
||||
if not is_torch_greater_or_equal_than_2_2:
|
||||
return
|
||||
|
||||
model_cls = AlbertForMaskedLM
|
||||
model_config = model_cls.config_class()
|
||||
model = model_cls(model_config)
|
||||
|
||||
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||
|
||||
ep = torch.export.export(model, (), input_dict)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.export.save(ep, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_ep = torch.export.load(buffer)
|
||||
|
||||
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
|
||||
|
||||
|
||||
class ModelOutputTestNoDataclass(ModelOutput):
|
||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||
|
||||
Reference in New Issue
Block a user