T5 compile compatibility (#34089)
* this worked in normal generation, needs more tests * fix almost all tests in t5 * nit * longt5, umt5, mt5 * style * udop, pix2struct * more models * fix some tests * fix onnx tests * tracing tests fixed * compile enabled and tested for t5 models * fix small bug in slow tests * [run-slow] t5 * uncomment * style * update with new generation refactoring * nit * fix copies * this is the fix, had to change t5 to fix copies * update * [run-slow] t5 * [run-slow] t5 * update * add test for encoder only T5 * clean up after rebase * fix pop2piano * add comment * style * fix copies after rebase * fix copies missed this one
This commit is contained in:
committed by
GitHub
parent
5077bc034f
commit
73d65e637b
@@ -41,6 +41,7 @@ if is_torch_fx_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -316,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
# The small UMT5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google/umt5-small"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = UMT5ModelTester(self)
|
||||
|
||||
@@ -486,6 +490,41 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
    """Check that a custom 4D decoder mask reproduces the batched decoder logits.

    For every generative model class, the decoder is run twice:

    1. on the plain batch of decoder sequences returned by
       `_get_custom_4d_mask_test_data`, and
    2. on the single "shared prefix" sequence together with its custom
       decoder attention mask.

    The softmax-normalized logits of the last token of each batch row must
    match those of the last three positions of the shared-prefix run.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config).to(device=torch_device, dtype=torch.float32)
        # Two separate forward passes are compared below; eval mode turns off
        # dropout so that both passes are deterministic.
        model.eval()

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        # Only inference outputs are compared, so no autograd graph is needed.
        with torch.no_grad():
            # Call the module (not `.forward`) so that registered hooks run.
            logits = model(
                decoder_input_ids=input_ids,
                input_ids=input_dict["input_ids"][:3],
            ).logits
            # logits.shape == torch.Size([3, 4, ...])

            logits_shared_prefix = model(
                input_ids=input_dict["input_ids"][:1],
                decoder_input_ids=input_ids_shared_prefix,
                decoder_attention_mask=mask_shared_prefix,
            ).logits
            # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits; `dim` is given explicitly because
        # the implicit-dimension form of softmax is deprecated and warns.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_with_sequence_classification_head(self):
    """Run the model tester's sequence-classification-head check on freshly prepared inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_with_sequence_classification_head(*prepared)
|
||||
|
||||
Reference in New Issue
Block a user