Fixed torch.finfo issue with torch.fx (#20040)

This commit is contained in:
Michael Benayoun
2022-11-03 16:14:44 +01:00
committed by GitHub
parent 6f257bb3c2
commit 9080607b2c
2 changed files with 31 additions and 28 deletions

View File

@@ -835,17 +835,14 @@ class ModelTesterMixin:
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
if (
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
and not hasattr(model.config, "problem_type")
or model.config.problem_type is None
if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and (
not hasattr(model.config, "problem_type") or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs)
except Exception as e:
self.fail(f"Couldn't trace module: {e}")
@@ -871,20 +868,6 @@ class ModelTesterMixin:
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be TorchScripted
try:
scripted = torch.jit.script(traced_model)
except Exception as e:
self.fail(f"Could not TorchScript the traced model: {e}")
scripted_output = scripted(**filtered_inputs)
scripted_output = flatten_output(scripted_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], scripted_output[i]),
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")