Fix for Neuron (#30259)
This commit is contained in:
committed by
ArthurZucker
parent
9fe3f585bb
commit
bb98e7ce58
@@ -1021,8 +1021,11 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
if attention_mask.dim() == 2:
|
if attention_mask.dim() == 2:
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
elif attention_mask.dim() == 4:
|
elif attention_mask.dim() == 4:
|
||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
|
|||||||
@@ -1007,8 +1007,11 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
if attention_mask.dim() == 2:
|
if attention_mask.dim() == 2:
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
elif attention_mask.dim() == 4:
|
elif attention_mask.dim() == 4:
|
||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
|
|||||||
@@ -1099,8 +1099,11 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
if attention_mask.dim() == 2:
|
if attention_mask.dim() == 2:
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
elif attention_mask.dim() == 4:
|
elif attention_mask.dim() == 4:
|
||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
|
|||||||
@@ -1082,8 +1082,11 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
if attention_mask.dim() == 2:
|
if attention_mask.dim() == 2:
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
elif attention_mask.dim() == 4:
|
elif attention_mask.dim() == 4:
|
||||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||||
|
|||||||
@@ -84,12 +84,12 @@ if is_torch_neuroncore_available(check_device=False):
|
|||||||
if os.environ.get("TORCHELASTIC_RUN_ID"):
|
if os.environ.get("TORCHELASTIC_RUN_ID"):
|
||||||
if is_optimum_neuron_available():
|
if is_optimum_neuron_available():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
|
"Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
|
||||||
"will fail otherwise."
|
"will fail otherwise."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
|
"Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
|
||||||
"training on AWS Trainium instances. More information here: "
|
"training on AWS Trainium instances. More information here: "
|
||||||
"https://github.com/huggingface/optimum-neuron"
|
"https://github.com/huggingface/optimum-neuron"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,22 +15,28 @@
|
|||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import collections
|
import collections
|
||||||
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.utils._pytree as pytree
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
|
||||||
from torch.fx._compatibility import compatibility
|
from torch.fx._compatibility import compatibility
|
||||||
|
from torch.fx._symbolic_trace import is_fx_tracing
|
||||||
from torch.fx.proxy import ParameterProxy
|
from torch.fx.proxy import ParameterProxy
|
||||||
|
|
||||||
from .. import PretrainedConfig, PreTrainedModel, logging
|
from .. import logging
|
||||||
|
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
|
||||||
|
from ..modeling_utils import PretrainedConfig, PreTrainedModel
|
||||||
from ..models.auto import get_values
|
from ..models.auto import get_values
|
||||||
from ..models.auto.modeling_auto import (
|
from ..models.auto.modeling_auto import (
|
||||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
@@ -55,7 +61,7 @@ from ..models.auto.modeling_auto import (
|
|||||||
MODEL_MAPPING_NAMES,
|
MODEL_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
|
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||||
from ..utils import (
|
from .import_utils import (
|
||||||
ENV_VARS_TRUE_VALUES,
|
ENV_VARS_TRUE_VALUES,
|
||||||
TORCH_FX_REQUIRED_VERSION,
|
TORCH_FX_REQUIRED_VERSION,
|
||||||
get_torch_version,
|
get_torch_version,
|
||||||
@@ -192,6 +198,8 @@ _SPECIAL_SUPPORTED_MODELS = [
|
|||||||
]
|
]
|
||||||
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
|
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
|
||||||
|
|
||||||
|
_CURRENT_TRACER = None
|
||||||
|
|
||||||
|
|
||||||
def torch_nn_embedding(self, input):
|
def torch_nn_embedding(self, input):
|
||||||
return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
|
return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
|
||||||
@@ -701,6 +709,92 @@ class MetaDeviceAttribute(HFAttribute):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HFCacheProxy(HFProxy):
|
||||||
|
"""
|
||||||
|
Proxy that represents an instance of `transformers.cache_utils.Cache`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __class__(self):
|
||||||
|
return ProxyableCache
|
||||||
|
|
||||||
|
|
||||||
|
def create_wrapper(
|
||||||
|
function: Callable,
|
||||||
|
op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
|
||||||
|
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
||||||
|
) -> Callable:
|
||||||
|
@functools.wraps(function)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not is_fx_tracing():
|
||||||
|
return function(*args, **kwargs)
|
||||||
|
|
||||||
|
found_proxies = []
|
||||||
|
|
||||||
|
def check_proxy(a):
|
||||||
|
if isinstance(a, Proxy):
|
||||||
|
found_proxies.append(a)
|
||||||
|
|
||||||
|
torch.fx.node.map_aggregate(args, check_proxy)
|
||||||
|
torch.fx.node.map_aggregate(kwargs, check_proxy)
|
||||||
|
|
||||||
|
if len(found_proxies) > 0:
|
||||||
|
tracer = found_proxies[0].tracer
|
||||||
|
if op_type == "call_function":
|
||||||
|
target = function
|
||||||
|
elif op_type == "call_method":
|
||||||
|
target = function.__name__
|
||||||
|
elif op_type == "get_attr":
|
||||||
|
target = function.__name__
|
||||||
|
else:
|
||||||
|
raise ValueError(f"op_type {op_type} not supported.")
|
||||||
|
return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
|
||||||
|
else:
|
||||||
|
return function(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class HFProxyableClassMeta(type):
|
||||||
|
"""
|
||||||
|
Metaclass that creates a class with its main methods wrapped to be proxyable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
bases: Tuple[Type, ...],
|
||||||
|
attrs: Dict[str, Any],
|
||||||
|
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
||||||
|
):
|
||||||
|
cls = super().__new__(cls, name, bases, attrs)
|
||||||
|
for attr_name in dir(cls):
|
||||||
|
attr = getattr(cls, attr_name, None)
|
||||||
|
if attr is None:
|
||||||
|
continue
|
||||||
|
if attr_name == "__init__":
|
||||||
|
op_type = "call_function"
|
||||||
|
elif attr_name.startswith("__"):
|
||||||
|
op_type = None
|
||||||
|
elif inspect.ismethod(attr):
|
||||||
|
op_type = "call_function"
|
||||||
|
elif inspect.isfunction(attr):
|
||||||
|
op_type = "call_method"
|
||||||
|
else:
|
||||||
|
op_type = None
|
||||||
|
if op_type is not None:
|
||||||
|
setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
|
||||||
|
"""
|
||||||
|
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
|
||||||
|
"""
|
||||||
|
wrapper = create_wrapper(target, "call_function")
|
||||||
|
return wrapper, target
|
||||||
|
|
||||||
|
|
||||||
def _proxies_to_metas(v):
|
def _proxies_to_metas(v):
|
||||||
"""Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
|
"""Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
|
||||||
if isinstance(v, MetaDeviceAttribute):
|
if isinstance(v, MetaDeviceAttribute):
|
||||||
@@ -712,25 +806,24 @@ def _proxies_to_metas(v):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def _gen_constructor_wrapper(target):
|
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
|
||||||
@functools.wraps(target)
|
global _CURRENT_TRACER
|
||||||
def wrapper(*args, **kwargs):
|
if not isinstance(_CURRENT_TRACER, HFTracer):
|
||||||
proxy = None
|
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
|
||||||
|
return HFCacheProxy(n, _CURRENT_TRACER)
|
||||||
|
|
||||||
def check_has_proxy(v):
|
|
||||||
if isinstance(v, Proxy):
|
|
||||||
nonlocal proxy
|
|
||||||
proxy = v
|
|
||||||
|
|
||||||
torch.fx.node.map_aggregate(args, check_has_proxy)
|
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
|
||||||
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
|
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
|
||||||
|
ProxyableDynamicCache = HFProxyableClassMeta(
|
||||||
if proxy is not None:
|
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
|
||||||
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
)
|
||||||
else:
|
ProxyableSinkCache = HFProxyableClassMeta(
|
||||||
return target(*args, **kwargs)
|
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
|
||||||
|
)
|
||||||
return wrapper, target
|
ProxyableStaticCache = HFProxyableClassMeta(
|
||||||
|
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
||||||
@@ -764,6 +857,13 @@ class HFTracer(Tracer):
|
|||||||
"finfo",
|
"finfo",
|
||||||
"tril",
|
"tril",
|
||||||
]
|
]
|
||||||
|
_CLASSES_TO_PATCH = {
|
||||||
|
Cache: ProxyableCache,
|
||||||
|
DynamicCache: ProxyableDynamicCache,
|
||||||
|
SinkCache: ProxyableSinkCache,
|
||||||
|
StaticCache: ProxyableStaticCache,
|
||||||
|
}
|
||||||
|
|
||||||
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
|
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
|
||||||
|
|
||||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||||
@@ -776,7 +876,7 @@ class HFTracer(Tracer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _generate_dummy_input(
|
def _generate_dummy_input(
|
||||||
self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
|
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""Generates dummy input for model inference recording."""
|
"""Generates dummy input for model inference recording."""
|
||||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||||
@@ -951,6 +1051,11 @@ class HFTracer(Tracer):
|
|||||||
args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
|
args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
|
||||||
kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
|
kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
|
||||||
|
|
||||||
|
should_install_metadata = True
|
||||||
|
|
||||||
|
self._disable_module_getattr = True
|
||||||
|
self._disable_call_module = True
|
||||||
|
|
||||||
if kind == "call_function":
|
if kind == "call_function":
|
||||||
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
|
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
|
||||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||||
@@ -963,39 +1068,36 @@ class HFTracer(Tracer):
|
|||||||
elif kind == "call_module":
|
elif kind == "call_module":
|
||||||
if not hasattr(self, "orig_forward"):
|
if not hasattr(self, "orig_forward"):
|
||||||
raise AttributeError(f"{self} does not have an attribute called orig_forward")
|
raise AttributeError(f"{self} does not have an attribute called orig_forward")
|
||||||
self._disable_module_getattr = True
|
mod = self.root.get_submodule(target)
|
||||||
try:
|
mod_type = type(mod)
|
||||||
mod = self.root.get_submodule(target)
|
if mod_type in _MANUAL_META_OVERRIDES:
|
||||||
mod_type = type(mod)
|
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
|
||||||
if mod_type in _MANUAL_META_OVERRIDES:
|
else:
|
||||||
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
|
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
||||||
else:
|
|
||||||
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
|
||||||
finally:
|
|
||||||
self._disable_module_getattr = False
|
|
||||||
elif kind == "get_attr":
|
elif kind == "get_attr":
|
||||||
self._disable_module_getattr = True
|
attr_itr = self.root
|
||||||
try:
|
atoms = target.split(".")
|
||||||
attr_itr = self.root
|
for atom in atoms:
|
||||||
atoms = target.split(".")
|
attr_itr = getattr(attr_itr, atom)
|
||||||
for atom in atoms:
|
if isinstance(attr_itr, torch.Tensor):
|
||||||
attr_itr = getattr(attr_itr, atom)
|
meta_out = attr_itr.to(device="meta")
|
||||||
if isinstance(attr_itr, torch.Tensor):
|
else:
|
||||||
meta_out = attr_itr.to(device="meta")
|
meta_out = attr_itr
|
||||||
else:
|
|
||||||
meta_out = attr_itr
|
|
||||||
finally:
|
|
||||||
self._disable_module_getattr = False
|
|
||||||
else:
|
else:
|
||||||
return rv
|
should_install_metadata = False
|
||||||
|
|
||||||
|
if should_install_metadata:
|
||||||
|
if not isinstance(rv, Proxy):
|
||||||
|
raise ValueError("Don't support composite output yet")
|
||||||
|
rv.install_metadata(meta_out)
|
||||||
|
|
||||||
if not isinstance(rv, Proxy):
|
|
||||||
raise ValueError("Don't support composite output yet")
|
|
||||||
rv.install_metadata(meta_out)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if _IS_IN_DEBUG_MODE:
|
if _IS_IN_DEBUG_MODE:
|
||||||
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
|
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
|
||||||
|
|
||||||
|
self._disable_module_getattr = False
|
||||||
|
self._disable_call_module = False
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
|
|
||||||
# Replaced by .getattr from PyTorch 1.13
|
# Replaced by .getattr from PyTorch 1.13
|
||||||
@@ -1041,12 +1143,51 @@ class HFTracer(Tracer):
|
|||||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||||
|
|
||||||
def call_module(self, m, forward, args, kwargs):
|
def call_module(self, m, forward, args, kwargs):
|
||||||
|
if getattr(self, "_disable_call_module", False):
|
||||||
|
return forward(*args, **kwargs)
|
||||||
self.orig_forward = forward
|
self.orig_forward = forward
|
||||||
return super().call_module(m, forward, args, kwargs)
|
return super().call_module(m, forward, args, kwargs)
|
||||||
|
|
||||||
def proxy(self, node):
|
def proxy(self, node):
|
||||||
return HFProxy(node, self)
|
return HFProxy(node, self)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
|
||||||
|
# Patching torch functions
|
||||||
|
self.patched_torch_methods = {
|
||||||
|
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
||||||
|
}
|
||||||
|
self.orig_fns = set()
|
||||||
|
|
||||||
|
for name, (wrapper, orig) in self.patched_torch_methods.items():
|
||||||
|
setattr(torch, name, wrapper)
|
||||||
|
self.orig_fns.add(orig)
|
||||||
|
|
||||||
|
# Patching classes
|
||||||
|
patched = []
|
||||||
|
module_of_model = inspect.getmodule(root)
|
||||||
|
for name, mod in sys.modules.items():
|
||||||
|
if module_of_model is not None and mod is not module_of_model:
|
||||||
|
continue
|
||||||
|
if not name.startswith("transformers"):
|
||||||
|
continue
|
||||||
|
for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
|
||||||
|
for attr_name, attr in mod.__dict__.items():
|
||||||
|
if attr is orig_cls:
|
||||||
|
patched.append((mod, attr_name, orig_cls))
|
||||||
|
setattr(mod, attr_name, patched_cls)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Restoring patched functions and classes.
|
||||||
|
for name, (_, orig) in self.patched_torch_methods.items():
|
||||||
|
setattr(torch, name, orig)
|
||||||
|
self.patched_torch_methods = {}
|
||||||
|
self.orig_fns = set()
|
||||||
|
|
||||||
|
for mod, attr_name, orig_cls in patched:
|
||||||
|
setattr(mod, attr_name, orig_cls)
|
||||||
|
|
||||||
def trace(
|
def trace(
|
||||||
self,
|
self,
|
||||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||||
@@ -1125,28 +1266,25 @@ class HFTracer(Tracer):
|
|||||||
" transformers.PreTrainedModel."
|
" transformers.PreTrainedModel."
|
||||||
)
|
)
|
||||||
|
|
||||||
concrete_metas = {
|
def to_meta(value):
|
||||||
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
|
if isinstance(value, torch.Tensor):
|
||||||
for input_name, input_ in inputs.items()
|
return value.to("meta")
|
||||||
}
|
return value
|
||||||
|
|
||||||
|
concrete_metas = pytree.tree_map(to_meta, inputs)
|
||||||
|
|
||||||
for param in sig.parameters.values():
|
for param in sig.parameters.values():
|
||||||
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
|
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
|
||||||
concrete_metas[f"**{param.name}"] = {}
|
concrete_metas[f"**{param.name}"] = {}
|
||||||
self.meta_args = concrete_metas
|
self.meta_args = concrete_metas
|
||||||
self.patched_torch_methods = {
|
|
||||||
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
|
||||||
}
|
|
||||||
self.orig_fns = set()
|
|
||||||
|
|
||||||
for name, (wrapper, orig) in self.patched_torch_methods.items():
|
global _CURRENT_TRACER
|
||||||
setattr(torch, name, wrapper)
|
_CURRENT_TRACER = self
|
||||||
self.orig_fns.add(orig)
|
with self.patch_for_tracing(root):
|
||||||
|
try:
|
||||||
try:
|
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||||
self.graph = super().trace(root, concrete_args=concrete_args)
|
finally:
|
||||||
finally:
|
_CURRENT_TRACER = None
|
||||||
for name, (_, orig) in self.patched_torch_methods.items():
|
|
||||||
setattr(torch, name, orig)
|
|
||||||
|
|
||||||
# This is necessary because concrete args are added as input to the traced module since
|
# This is necessary because concrete args are added as input to the traced module since
|
||||||
# https://github.com/pytorch/pytorch/pull/55888.
|
# https://github.com/pytorch/pytorch/pull/55888.
|
||||||
@@ -1256,11 +1394,11 @@ def get_concrete_args(model: nn.Module, input_names: List[str]):
|
|||||||
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
||||||
|
|
||||||
|
|
||||||
def is_model_supported(model: PreTrainedModel):
|
def is_model_supported(model: "PreTrainedModel"):
|
||||||
return model.__class__.__name__ in _SUPPORTED_MODELS
|
return model.__class__.__name__ in _SUPPORTED_MODELS
|
||||||
|
|
||||||
|
|
||||||
def check_if_model_is_supported(model: PreTrainedModel):
|
def check_if_model_is_supported(model: "PreTrainedModel"):
|
||||||
if not is_model_supported(model):
|
if not is_model_supported(model):
|
||||||
supported_model_names = ", ".join(_SUPPORTED_MODELS)
|
supported_model_names = ", ".join(_SUPPORTED_MODELS)
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -1269,7 +1407,7 @@ def check_if_model_is_supported(model: PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
def symbolic_trace(
|
def symbolic_trace(
|
||||||
model: PreTrainedModel,
|
model: "PreTrainedModel",
|
||||||
input_names: Optional[List[str]] = None,
|
input_names: Optional[List[str]] = None,
|
||||||
disable_check: bool = False,
|
disable_check: bool = False,
|
||||||
tracer_cls: Type[HFTracer] = HFTracer,
|
tracer_cls: Type[HFTracer] = HFTracer,
|
||||||
@@ -1307,6 +1445,18 @@ def symbolic_trace(
|
|||||||
if not disable_check:
|
if not disable_check:
|
||||||
check_if_model_is_supported(model)
|
check_if_model_is_supported(model)
|
||||||
|
|
||||||
|
if "past_key_values" in input_names and not getattr(model.config, "use_cache", False):
|
||||||
|
logger.warning(
|
||||||
|
"`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
|
||||||
|
"unexpected behavior."
|
||||||
|
)
|
||||||
|
if "past_key_values" not in input_names and getattr(model.config, "use_cache", False):
|
||||||
|
logger.warning(
|
||||||
|
"`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
|
||||||
|
"model.config.use_cache = False."
|
||||||
|
)
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
# Tracing.
|
# Tracing.
|
||||||
tracer = tracer_cls()
|
tracer = tracer_cls()
|
||||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import gc
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import pickle
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -1279,26 +1278,6 @@ class ModelTesterMixin:
|
|||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test that the model can be serialized and restored properly
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
|
||||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
|
||||||
try:
|
|
||||||
with open(pkl_file_name, "wb") as f:
|
|
||||||
pickle.dump(traced_model, f)
|
|
||||||
with open(pkl_file_name, "rb") as f:
|
|
||||||
loaded = pickle.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
|
||||||
|
|
||||||
loaded_output = loaded(**filtered_inputs)
|
|
||||||
loaded_output = flatten_output(loaded_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], loaded_output[i]),
|
|
||||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||||
self.clear_torch_jit_class_registry()
|
self.clear_torch_jit_class_registry()
|
||||||
|
|||||||
Reference in New Issue
Block a user