fix recompiles due to instance key, and deepcopy issues (#39270)

* fix recompiles due to instance key, and deepcopy issues

* dict
This commit is contained in:
Arthur
2025-07-08 11:38:11 +02:00
committed by GitHub
parent 356fd68109
commit 5fb8bb3e1a
2 changed files with 3 additions and 4 deletions

View File

@@ -2100,7 +2100,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
self._no_split_modules = self._no_split_modules or [] self._no_split_modules = self._no_split_modules or []
_CAN_RECORD_REGISTRY[self] = self._can_record_outputs # added for executorch support only _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
def post_init(self): def post_init(self):
""" """

View File

@@ -27,7 +27,6 @@ from dataclasses import dataclass, fields, is_dataclass
from enum import Enum from enum import Enum
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Callable, ContextManager, Optional, TypedDict from typing import Any, Callable, ContextManager, Optional, TypedDict
from weakref import WeakKeyDictionary
import numpy as np import numpy as np
from packaging import version from packaging import version
@@ -44,7 +43,7 @@ from .import_utils import (
) )
_CAN_RECORD_REGISTRY = WeakKeyDictionary() _CAN_RECORD_REGISTRY = {}
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -1008,7 +1007,7 @@ def check_model_inputs(func):
for k, v in all_args["kwargs"].items(): for k, v in all_args["kwargs"].items():
all_args[k] = v all_args[k] = v
capture_flags = _CAN_RECORD_REGISTRY[self] or [] # there is a weak ref for executorch capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) # there is a weak ref for executorch
recordable_keys = { recordable_keys = {
f"output_{k}": all_args.get( f"output_{k}": all_args.get(
f"output_{k}", f"output_{k}",