A cleaner and more scalable implementation of symbolic tracing (#11763)
A cleaner and more scalable implementation of symbolic tracing with torch.fx, which also adds support for new architectures: ALBERT, DistilBERT, MobileBERT, MegatronBERT, GPT2, and GPT Neo. Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -229,6 +229,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
|
||||
@@ -600,9 +600,9 @@ class ModelTesterMixin:
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
|
||||
model_output = model(**prepared_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
encoder_sequence_length = input_ids.shape[1]
|
||||
@@ -615,26 +615,37 @@ class ModelTesterMixin:
|
||||
sequence_length=[encoder_sequence_length, decoder_sequence_length],
|
||||
)
|
||||
|
||||
traced_output = traced_model(**prepared_inputs)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
else:
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
input_ids = inputs["input_ids"]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
model_output = model(**prepared_inputs)
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
sequence_length = input_ids.shape[2]
|
||||
num_choices = input_ids.shape[1]
|
||||
else:
|
||||
sequence_length = input_ids.shape[1]
|
||||
rank = len(input_ids.shape)
|
||||
if rank == 2:
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
num_choices = -1
|
||||
elif rank == 3:
|
||||
batch_size, num_choices, sequence_length = input_ids.shape
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
||||
)
|
||||
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
@@ -643,14 +654,31 @@ class ModelTesterMixin:
|
||||
sequence_length=sequence_length,
|
||||
num_choices=num_choices,
|
||||
)
|
||||
traced_output = traced_model(**prepared_inputs)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
def flatten_output(output):
    """Collect every tensor found in a nested output structure.

    Iterates over ``output`` in order, descending recursively into any
    tuple or list it meets, and returns a flat list of the
    ``torch.Tensor`` items encountered. Items that are neither tensors
    nor tuples/lists are dropped.
    """
    tensors = []
    for item in output:
        if isinstance(item, (tuple, list)):
            # Nested container: splice its tensors in at this position.
            tensors.extend(flatten_output(item))
        elif isinstance(item, torch.Tensor):
            tensors.append(item)
        # Anything else (None, scalars, strings, ...) is ignored.
    return tensors
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs))
|
||||
self.assertTrue(outputs_are_close)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
|
||||
@@ -208,6 +208,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else None
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
|
||||
@@ -399,6 +399,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_missing_keys = False
|
||||
test_model_parallel = True
|
||||
|
||||
|
||||
@@ -276,6 +276,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
|
||||
all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
test_model_parallel = False
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2021 NVIDIA Corporation. All rights reserved.
|
||||
#
|
||||
@@ -282,6 +281,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
|
||||
# test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
@@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
Reference in New Issue
Block a user