T5 compile compatibility (#34089)
* this worked in normal generation, needs more tests * fix almost all tests in t5 * nit * longt5, umt5, mt5 * style * udop, pix2struct * more models * fix some tests * fix onnx tests * tracing tests fixed * compile enabled and tested for t5 models * fix small bug in slow tests * [run-slow] t5 * uncomment * style * update with new generation refactoring * nit * fix copies * this is the fix, had to change t5 to fix copies * update * [run-slow] t5 * [run-slow] t5 * update * add test for encoder only T5 * clean up after rebase * fix pop2piano * add comment * style * fix copies after rebase * fix copies missed this one
This commit is contained in:
committed by
GitHub
parent
5077bc034f
commit
73d65e637b
@@ -40,6 +40,7 @@ if is_torch_fx_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
@@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# The small MT5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google/mt5-small"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MT5ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)
|
||||
@@ -627,12 +631,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
@@ -647,7 +648,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
@@ -657,15 +657,12 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
@@ -718,6 +715,41 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
# overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
    """Compare decoder logits from a plain batch against a shared-prefix run.

    For every class in ``self.all_generative_model_classes`` the model is
    called twice:

    * once with ``decoder_input_ids`` set to the batched ids returned by
      ``self._get_custom_4d_mask_test_data()`` and the first three rows of
      the encoder ``input_ids``;
    * once with the shared-prefix ids and their custom
      ``decoder_attention_mask`` and the first row of the encoder
      ``input_ids``.

    The softmax of the last position of each row of the first call must
    match the softmax of the last three positions of the second call
    (``rtol=1e-3``, ``atol=1e-4``).
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is not switched to eval mode here; if the
        # tester config enables dropout the two passes may not be
        # deterministic -- confirm against the shared test helpers.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        # The helper returns five values; the second and fifth are unused here.
        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # Comparing softmax-normalized logits. `dim=-1` is passed explicitly:
        # omitting it relies on PyTorch's deprecated implicit-dimension
        # selection, which emits a warning. Both slices are 2-D, so the
        # last axis is the same axis the implicit rule would have picked.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user