fix recompiles due to instance key, and deepcopy issues (#39270)
* fix recompiles due to instance key, and deepcopy issues * dict
This commit is contained in:
@@ -2100,7 +2100,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
|
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
|
||||||
|
|
||||||
self._no_split_modules = self._no_split_modules or []
|
self._no_split_modules = self._no_split_modules or []
|
||||||
_CAN_RECORD_REGISTRY[self] = self._can_record_outputs # added for executorch support only
|
_CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from dataclasses import dataclass, fields, is_dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, Callable, ContextManager, Optional, TypedDict
|
from typing import Any, Callable, ContextManager, Optional, TypedDict
|
||||||
from weakref import WeakKeyDictionary
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -44,7 +43,7 @@ from .import_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_CAN_RECORD_REGISTRY = WeakKeyDictionary()
|
_CAN_RECORD_REGISTRY = {}
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -1008,7 +1007,7 @@ def check_model_inputs(func):
|
|||||||
for k, v in all_args["kwargs"].items():
|
for k, v in all_args["kwargs"].items():
|
||||||
all_args[k] = v
|
all_args[k] = v
|
||||||
|
|
||||||
capture_flags = _CAN_RECORD_REGISTRY[self] or [] # there is a weak ref for executorch
|
capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) # there is a weak ref for executorch
|
||||||
recordable_keys = {
|
recordable_keys = {
|
||||||
f"output_{k}": all_args.get(
|
f"output_{k}": all_args.get(
|
||||||
f"output_{k}",
|
f"output_{k}",
|
||||||
|
|||||||
Reference in New Issue
Block a user