A cleaner and more scalable implementation of symbolic tracing (#11763)

A cleaner and more scalable implementation of symbolic tracing with torch.fx, which also adds support for new architectures:
- ALBERT
- DistilBERT
- MobileBERT
- MegatronBERT
- GPT2
- GPT Neo

Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun
2021-05-20 18:02:29 +02:00
committed by GitHub
parent 469384a777
commit f4a0d6ff86
8 changed files with 260 additions and 114 deletions

View File

@@ -600,9 +600,9 @@ class ModelTesterMixin:
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
if labels is not None:
input_names.append("labels")
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
model_output = model(**prepared_inputs)
model_output = model(**filtered_inputs)
batch_size = input_ids.shape[0]
encoder_sequence_length = input_ids.shape[1]
@@ -615,26 +615,37 @@ class ModelTesterMixin:
sequence_length=[encoder_sequence_length, decoder_sequence_length],
)
traced_output = traced_model(**prepared_inputs)
traced_output = traced_model(**filtered_inputs)
else:
input_ids = inputs["input_ids"]
labels = inputs.get("labels", None)
input_names = ["input_ids", "attention_mask", "token_type_ids"]
input_ids = inputs["input_ids"]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
model_output = model(**prepared_inputs)
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = filtered_inputs.keys()
batch_size = input_ids.shape[0]
model_output = model(**filtered_inputs)
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
sequence_length = input_ids.shape[2]
num_choices = input_ids.shape[1]
else:
sequence_length = input_ids.shape[1]
rank = len(input_ids.shape)
if rank == 2:
batch_size, sequence_length = input_ids.shape
num_choices = -1
elif rank == 3:
batch_size, num_choices, sequence_length = input_ids.shape
else:
raise NotImplementedError(
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
)
traced_model = symbolic_trace(
model,
@@ -643,14 +654,31 @@ class ModelTesterMixin:
sequence_length=sequence_length,
num_choices=num_choices,
)
traced_output = traced_model(**prepared_inputs)
traced_output = traced_model(**filtered_inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
def flatten_output(output):
    """Recursively collect every ``torch.Tensor`` found in ``output`` into a flat list.

    Tuples, lists and mappings are walked depth-first, preserving their
    iteration order. Any element that is neither a container nor a tensor
    (for example ``None``) is ignored.

    Args:
        output: An iterable of tensors and/or nested containers, or a mapping
            whose values are such items.

    Returns:
        A list with the tensors in the order they were encountered.
    """
    from collections.abc import Mapping

    # Iterating a mapping yields its keys (strings); those would all be skipped
    # below and the result would be empty, so walk the values instead.
    if isinstance(output, Mapping):
        output = output.values()

    flatten = []
    for x in output:
        if isinstance(x, (tuple, list, Mapping)):
            flatten.extend(flatten_output(x))
        elif isinstance(x, torch.Tensor):
            flatten.append(x)
        # Anything else carries no tensor data to compare and is dropped.
    return flatten
model_output = flatten_output(model_output)
traced_output = flatten_output(traced_output)
num_outputs = len(model_output)
outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs))
self.assertTrue(outputs_are_close)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], traced_output[i]),
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
def test_headmasking(self):
if not self.test_head_masking: