From 5fb8bb3e1a897bb46a709e51fb393412e9a15ea8 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 8 Jul 2025 11:38:11 +0200 Subject: [PATCH] fix recompiles due to instance key, and deepcopy issues (#39270) * fix recompiles due to instance key, and deepcopy issues * dict --- src/transformers/modeling_utils.py | 2 +- src/transformers/utils/generic.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f4a928922e..37d21f9fdf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2100,7 +2100,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) self._no_split_modules = self._no_split_modules or [] - _CAN_RECORD_REGISTRY[self] = self._can_record_outputs # added for executorch support only + _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only def post_init(self): """ diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 5326d48d74..8b6afb72ed 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -27,7 +27,6 @@ from dataclasses import dataclass, fields, is_dataclass from enum import Enum from functools import partial, wraps from typing import Any, Callable, ContextManager, Optional, TypedDict -from weakref import WeakKeyDictionary import numpy as np from packaging import version @@ -44,7 +43,7 @@ from .import_utils import ( ) -_CAN_RECORD_REGISTRY = WeakKeyDictionary() +_CAN_RECORD_REGISTRY = {} logger = logging.get_logger(__name__) @@ -1008,7 +1007,7 @@ def check_model_inputs(func): for k, v in all_args["kwargs"].items(): all_args[k] = v - capture_flags = _CAN_RECORD_REGISTRY[self] or [] # there is a weak ref for executorch + capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) # there is a weak ref for executorch recordable_keys = { f"output_{k}": all_args.get( f"output_{k}",