[core] support tensor-valued _extra_state values in from_pretrained (#38155)

Support tensor-valued _extra_state values

TransformerEngine uses the PyTorch get_extra_state/set_extra_state API to
store FP8 layer config information as a bytes tensor in the _extra_state
entry of the state dict. Recent changes to from_pretrained broke this
functionality, and loading a model that uses this API no longer appears
to work. This PR fixes the save/load pretrained functions for extra state
entries that are PyTorch tensors, and adds a (currently x-failing) test
for a dictionary-valued extra state.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
This commit is contained in:
Peter St. John
2025-05-28 07:38:42 -06:00
committed by GitHub
parent badc71b9f6
commit bab40c6838
2 changed files with 94 additions and 3 deletions

View File

@@ -2815,3 +2815,86 @@ class TestTensorSharing(TestCasePlus):
shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
self.assertEqual(shared_names, [{"a", "b"}])
self.assertEqual(identical_names, [])
@require_torch
class TestSaveAndLoadModelWithExtraState(TestCasePlus):
    """
    Round-trip tests for models whose submodules rely on the torch extra state API
    (`get_extra_state` / `set_extra_state`), see
    https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state.

    Only tensor-valued extra states are supported for now; the dict-valued case is
    marked as an expected failure.
    """

    def test_save_and_load_model_with_tensor_extra_state(self):
        # The extra state is a tensor here: this is the supported case.

        class MyConfig(PretrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.some_counter = 0
                self.linear = torch.nn.Linear(320, 320)

            def get_extra_state(self):
                # Expose the counter as a tensor so it is stored alongside the weights.
                counter_as_tensor = torch.tensor(self.some_counter)
                return counter_as_tensor

            def set_extra_state(self, state):
                # Turn the stored tensor back into a plain Python number.
                self.some_counter = state.item()

        class MyModel(PreTrainedModel):
            config_class = MyConfig

            def __init__(self, config: MyConfig):
                super().__init__(config)
                self.my_layer = MyModule()

            def forward(self, hidden_states, attention_mask):
                # Not called explicitly anywhere in this test.
                return self.my_layer(hidden_states, attention_mask)

        original = MyModel(MyConfig())
        # Use a value different from the constructor default (0) so a successful
        # restore is distinguishable from a freshly initialised module.
        original.my_layer.some_counter = 42

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            original.save_pretrained(checkpoint_dir)
            restored = MyModel.from_pretrained(checkpoint_dir)
            self.assertEqual(restored.my_layer.some_counter, 42)

    @mark.xfail(reason="save and from_pretrained currently only supports tensor extra_state")
    def test_save_and_load_model_with_dict_extra_state(self):
        # Same round trip as the tensor test, but the extra state is a plain dict.

        class MyConfig(PretrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.some_counter = 0
                self.linear = torch.nn.Linear(320, 320)

            def get_extra_state(self):
                # Dictionary-valued extra state (the not-yet-supported case).
                extra_state = {"some_counter": self.some_counter}
                return extra_state

            def set_extra_state(self, state):
                self.some_counter = state["some_counter"]

        class MyModel(PreTrainedModel):
            config_class = MyConfig

            def __init__(self, config: MyConfig):
                super().__init__(config)
                self.my_layer = MyModule()

            def forward(self, hidden_states, attention_mask):
                # Not called explicitly anywhere in this test.
                return self.my_layer(hidden_states, attention_mask)

        original = MyModel(MyConfig())
        # Use a value different from the constructor default (0) so a successful
        # restore is distinguishable from a freshly initialised module.
        original.my_layer.some_counter = 42

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            original.save_pretrained(checkpoint_dir)
            restored = MyModel.from_pretrained(checkpoint_dir)
            self.assertEqual(restored.my_layer.some_counter, 42)