fix recompiles due to instance key, and deepcopy issues (#39270)
* fix recompiles due to instance key, and deepcopy issues * dict
This commit is contained in:
@@ -2100,7 +2100,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
|
||||
|
||||
self._no_split_modules = self._no_split_modules or []
|
||||
_CAN_RECORD_REGISTRY[self] = self._can_record_outputs # added for executorch support only
|
||||
_CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
|
||||
|
||||
def post_init(self):
|
||||
"""
|
||||
|
||||
@@ -27,7 +27,6 @@ from dataclasses import dataclass, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, ContextManager, Optional, TypedDict
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
@@ -44,7 +43,7 @@ from .import_utils import (
|
||||
)
|
||||
|
||||
|
||||
_CAN_RECORD_REGISTRY = WeakKeyDictionary()
|
||||
_CAN_RECORD_REGISTRY = {}
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -1008,7 +1007,7 @@ def check_model_inputs(func):
|
||||
for k, v in all_args["kwargs"].items():
|
||||
all_args[k] = v
|
||||
|
||||
capture_flags = _CAN_RECORD_REGISTRY[self] or [] # there is a weak ref for executorch
|
||||
capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) # there is a weak ref for executorch
|
||||
recordable_keys = {
|
||||
f"output_{k}": all_args.get(
|
||||
f"output_{k}",
|
||||
|
||||
Reference in New Issue
Block a user