T5 compile compatibilty (#34089)

* this worked in normal generation, needs more tests

* fix almost all tests in t5

* nit

* longt5, umt5, mt5

* style

* udop, pix2struct

* more models

* fix some tests

* fix onnx tests

* tracing tests fixed

* compile enabled and tested for t5 models

* fix small bug in slow tests

* [run-slow] t5

* uncomment

* style

* update with new generation refactoring

* nit

* fix copies

* this is the fix, had to change t5 to fix copies

* update

* [run-slow] t5

* [run-slow] t5

* update

* add test for encoder only T5

* clean up after rebase

* fix pop2piano

* add comment

* style

* fix copies after rebase

* fix copies  missed this one
This commit is contained in:
Raushan Turganbay
2024-10-22 08:23:53 +02:00
committed by GitHub
parent 5077bc034f
commit 73d65e637b
22 changed files with 2744 additions and 1179 deletions

View File

@@ -40,6 +40,7 @@ if is_torch_fx_available():
if is_torch_available():
import torch
import torch.nn.functional as F
from transformers import (
AutoModelForSeq2SeqLM,
@@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small MT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google/mt5-small"
def setUp(self):
self.model_tester = MT5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)
@@ -627,12 +631,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
@@ -647,7 +648,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"visual_feats",
"visual_pos",
]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
@@ -657,15 +657,12 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
not hasattr(model.config, "problem_type") or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs)
@@ -718,6 +715,41 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
# overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
for model_class in self.all_generative_model_classes:
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(device=torch_device, dtype=torch.float32)
(
input_ids,
_,
input_ids_shared_prefix,
mask_shared_prefix,
_,
) = self._get_custom_4d_mask_test_data()
logits = model.forward(
decoder_input_ids=input_ids,
input_ids=input_dict["input_ids"][:3],
).logits
# logits.shape == torch.Size([3, 4, ...])
logits_shared_prefix = model(
input_ids=input_dict["input_ids"][:1],
decoder_input_ids=input_ids_shared_prefix,
decoder_attention_mask=mask_shared_prefix,
)[0]
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing softmax-normalized logits:
normalized_0 = F.softmax(out_last_tokens)
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_config(self):
self.config_tester.run_common_tests()