T5 compile compatibilty (#34089)
* this worked in normal generation, needs more tests * fix almost all tests in t5 * nit * longt5, umt5, mt5 * style * udop, pix2struct * more models * fix some tests * fix onnx tests * tracing tests fixed * compile enabled and tested for t5 models * fix small bug in slow tests * [run-slow] t5 * uncomment * style * update with new generation refactoring * nit * fix copies * this is the fix, had to change t5 to fix copies * update * [run-slow] t5 * [run-slow] t5 * update * add test for encoder only T5 * clean up after rebase * fix pop2piano * add comment * style * fix copies after rebase * fix copies missed this one
This commit is contained in:
committed by
GitHub
parent
5077bc034f
commit
73d65e637b
@@ -31,6 +31,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
@@ -574,6 +575,41 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
lm_labels,
|
||||
)
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
@@ -602,7 +638,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/longt5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=13,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ if is_torch_fx_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
@@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# The small MT5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google/mt5-small"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MT5ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)
|
||||
@@ -627,12 +631,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
@@ -647,7 +648,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
@@ -657,15 +657,12 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
@@ -718,6 +715,41 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
# overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
|
||||
@@ -620,7 +620,7 @@ class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/Pop2Piano_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -645,6 +646,41 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
lm_labels,
|
||||
)
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -44,6 +45,7 @@ if is_torch_fx_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -578,6 +580,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# The small T5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google-t5/t5-small"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = T5ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
|
||||
@@ -630,12 +635,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
@@ -650,7 +652,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
@@ -660,15 +661,12 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
@@ -721,6 +719,41 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@@ -1482,6 +1515,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
self.assertEqual(512, dct["input_ids"].shape[1])
|
||||
@@ -1604,14 +1638,76 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
|
||||
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
# TODO: @arthur?
|
||||
# PR #31938 caused regression on this test which was fixed by PR #34089
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for "
|
||||
"permanent residence after the marriages, prosecutors say."
|
||||
"Liana Barrientos has been married 10 times, nine of them in the Bronx . Her husbands filed for "
|
||||
"permanent residence after the marriages, prosecutors say ."
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_compile_static_cache(self):
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"theory of relativity states that 1) the speed of light is constant in all inertial reference frames. the laws of physics are the same for all inertial reference frames.",
|
||||
"ketchup is my favorite condiment.",
|
||||
]
|
||||
|
||||
prompts = [
|
||||
"summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||
"theory of relativity is not hard to grasp.",
|
||||
"summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
|
||||
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.",
|
||||
]
|
||||
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device)
|
||||
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
# Dynamic Cache
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||
|
||||
# Static Cache
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
# Static Cache + compile
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_compile_static_cache_encoder(self):
|
||||
prompts = [
|
||||
"summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||
"theory of relativity is not hard to grasp.",
|
||||
"summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
|
||||
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.",
|
||||
]
|
||||
model = T5EncoderModel.from_pretrained("google-t5/t5-small").to(torch_device)
|
||||
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
logits = model(**inputs)
|
||||
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
logits_compiled = model(**inputs)
|
||||
self.assertTrue(torch.allclose(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], atol=1e-5))
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestAsymmetricT5(unittest.TestCase):
|
||||
|
||||
@@ -37,6 +37,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor
|
||||
|
||||
@@ -348,6 +349,7 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
expected_arg_names = [
|
||||
"attention_mask",
|
||||
"bbox",
|
||||
"cache_position",
|
||||
"cross_attn_head_mask",
|
||||
"decoder_attention_mask",
|
||||
"decoder_head_mask",
|
||||
@@ -365,6 +367,43 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
expected_arg_names = sorted(expected_arg_names)
|
||||
self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names)
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
bbox=input_dict["bbox"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
bbox=input_dict["bbox"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
@unittest.skip(
|
||||
"Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!"
|
||||
)
|
||||
@@ -534,6 +573,41 @@ class UdopEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
@unittest.skip(
|
||||
"Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!"
|
||||
)
|
||||
|
||||
@@ -41,6 +41,7 @@ if is_torch_fx_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -316,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
# The small UMT5 model needs higher percentages for CPU/MP tests
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google/umt5-small"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = UMT5ModelTester(self)
|
||||
|
||||
@@ -486,6 +490,41 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
|
||||
def test_custom_4d_attention_mask(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
_,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
_,
|
||||
) = self._get_custom_4d_mask_test_data()
|
||||
|
||||
logits = model.forward(
|
||||
decoder_input_ids=input_ids,
|
||||
input_ids=input_dict["input_ids"][:3],
|
||||
).logits
|
||||
# logits.shape == torch.Size([3, 4, ...])
|
||||
|
||||
logits_shared_prefix = model(
|
||||
input_ids=input_dict["input_ids"][:1],
|
||||
decoder_input_ids=input_ids_shared_prefix,
|
||||
decoder_attention_mask=mask_shared_prefix,
|
||||
)[0]
|
||||
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||
|
||||
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||
|
||||
# comparing softmax-normalized logits:
|
||||
normalized_0 = F.softmax(out_last_tokens)
|
||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_with_sequence_classification_head(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user