[core] support tensor-valued _extra_state values in from_pretrained (#38155)
Support tensor-valued _extra_state values TransformerEngine uses the pytorch get/set_extra_state API to store FP8 layer config information as bytes Tensor in the _extra_state entry in the state dict. With recent changes to from_pretrained, this functionality has broken and loading a model that uses this API doesn't appear to work. This PR fixes the save/load pretrained functions for extra state entries that use a pytorch tensor, and adds a (currently x-failing) test for a dictionary extra state. Signed-off-by: Peter St. John <pstjohn@nvidia.com>
This commit is contained in:
@@ -5577,8 +5577,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
def get_parameter_or_buffer(self, target: str):
|
||||
"""
|
||||
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
|
||||
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a
|
||||
leaf of the model.
|
||||
`get_parameter()` and `get_buffer()` in a single handy function. If the target is an `_extra_state` attribute,
|
||||
it will return the extra state provided by the module. Note that it only work if `target` is a leaf of the model.
|
||||
"""
|
||||
try:
|
||||
return self.get_parameter(target)
|
||||
@@ -5588,7 +5588,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
return self.get_buffer(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
|
||||
module, param_name = get_module_from_name(self, target)
|
||||
if (
|
||||
param_name == "_extra_state"
|
||||
and getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
return module.get_extra_state()
|
||||
|
||||
raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
|
||||
|
||||
|
||||
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
|
||||
|
||||
Reference in New Issue
Block a user