Fixed torch.finfo issue with torch.fx (#20040)
This commit is contained in:
@@ -230,6 +230,15 @@ def torch_arange(*args, **kwargs):
|
|||||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_full(*args, **kwargs):
|
||||||
|
args = list(args)
|
||||||
|
if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
|
||||||
|
args[1] = 1 # Any value.
|
||||||
|
kwargs_without_device = dict(kwargs)
|
||||||
|
kwargs_without_device.pop("device", None)
|
||||||
|
return torch.full(*args, **kwargs_without_device)
|
||||||
|
|
||||||
|
|
||||||
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||||
if dim is None and axis is None:
|
if dim is None and axis is None:
|
||||||
dim = 0
|
dim = 0
|
||||||
@@ -509,6 +518,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
|||||||
torch.where: torch_where,
|
torch.where: torch_where,
|
||||||
torch.abs: torch_abs,
|
torch.abs: torch_abs,
|
||||||
torch.arange: torch_arange,
|
torch.arange: torch_arange,
|
||||||
|
torch.full: torch_full,
|
||||||
torch.cat: torch_cat,
|
torch.cat: torch_cat,
|
||||||
torch.stack: torch_stack,
|
torch.stack: torch_stack,
|
||||||
torch.add: torch_add,
|
torch.add: torch_add,
|
||||||
@@ -552,12 +562,6 @@ class HFProxy(Proxy):
|
|||||||
def shape(self):
|
def shape(self):
|
||||||
return self.tracer.create_proxy("call_method", "size", (self,), {})
|
return self.tracer.create_proxy("call_method", "size", (self,), {})
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
if hasattr(self, "_metadata") and self._metadata is not None:
|
|
||||||
return self._metadata.dtype
|
|
||||||
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
# Hack so we can track when devices are used. During meta-tensor propagation,
|
# Hack so we can track when devices are used. During meta-tensor propagation,
|
||||||
@@ -597,12 +601,15 @@ class HFAttribute(HFProxy):
|
|||||||
self.tracer = root.tracer
|
self.tracer = root.tracer
|
||||||
self._node = None
|
self._node = None
|
||||||
|
|
||||||
|
if hasattr(self.root, "_metadata"):
|
||||||
|
self.install_metadata(getattr(self.root._metadata, attr))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node(self):
|
def node(self):
|
||||||
# the node for attributes is added lazily, since most will just be method calls
|
# the node for attributes is added lazily, since most will just be method calls
|
||||||
# which do not rely on the getitem call
|
# which do not rely on the getitem call
|
||||||
if self._node is None:
|
if self._node is None:
|
||||||
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
|
self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
|
||||||
return self._node
|
return self._node
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
@@ -663,7 +670,18 @@ class HFTracer(Tracer):
|
|||||||
# Feature flag for proxying accesses to buffer values
|
# Feature flag for proxying accesses to buffer values
|
||||||
proxy_buffer_attributes: bool = True
|
proxy_buffer_attributes: bool = True
|
||||||
allow_insert_stateless_mods: bool = True
|
allow_insert_stateless_mods: bool = True
|
||||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
_TORCH_METHODS_TO_PATCH = [
|
||||||
|
"arange",
|
||||||
|
"zeros",
|
||||||
|
"ones",
|
||||||
|
"full",
|
||||||
|
"full_like",
|
||||||
|
"eye",
|
||||||
|
"empty",
|
||||||
|
"tensor",
|
||||||
|
"clamp",
|
||||||
|
"finfo",
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||||
|
|
||||||
@@ -737,6 +755,8 @@ class HFTracer(Tracer):
|
|||||||
"GPT2DoubleHeadsModel",
|
"GPT2DoubleHeadsModel",
|
||||||
]:
|
]:
|
||||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
|
elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
|
||||||
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
|
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
|
||||||
|
|||||||
@@ -835,17 +835,14 @@ class ModelTesterMixin:
|
|||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||||
input_names = list(filtered_inputs.keys())
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and (
|
||||||
|
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||||
if (
|
|
||||||
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
|
|
||||||
and not hasattr(model.config, "problem_type")
|
|
||||||
or model.config.problem_type is None
|
|
||||||
):
|
):
|
||||||
model.config.problem_type = "single_label_classification"
|
model.config.problem_type = "single_label_classification"
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
traced_model = symbolic_trace(model, input_names)
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
self.fail(f"Couldn't trace module: {e}")
|
||||||
@@ -871,20 +868,6 @@ class ModelTesterMixin:
|
|||||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test that the model can be TorchScripted
|
|
||||||
try:
|
|
||||||
scripted = torch.jit.script(traced_model)
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Could not TorchScript the traced model: {e}")
|
|
||||||
scripted_output = scripted(**filtered_inputs)
|
|
||||||
scripted_output = flatten_output(scripted_output)
|
|
||||||
|
|
||||||
for i in range(num_outputs):
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(model_output[i], scripted_output[i]),
|
|
||||||
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the model can be serialized and restored properly
|
# Test that the model can be serialized and restored properly
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||||
|
|||||||
Reference in New Issue
Block a user