Fixed torch.finfo issue with torch.fx (#20040)
This commit is contained in:
@@ -835,17 +835,14 @@ class ModelTesterMixin:
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
if (
|
||||
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
|
||||
and not hasattr(model.config, "problem_type")
|
||||
or model.config.problem_type is None
|
||||
if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
@@ -871,20 +868,6 @@ class ModelTesterMixin:
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be TorchScripted
|
||||
try:
|
||||
scripted = torch.jit.script(traced_model)
|
||||
except Exception as e:
|
||||
self.fail(f"Could not TorchScript the traced model: {e}")
|
||||
scripted_output = scripted(**filtered_inputs)
|
||||
scripted_output = flatten_output(scripted_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], scripted_output[i]),
|
||||
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
|
||||
Reference in New Issue
Block a user