Tests: move generate tests to the right mixin and delete redundant tests (#34464)
* tmp commit
* tmp commit
* cull overwrites of deleted tests
* typo
* more specific docstring
* make fixup
* parameterize at the top?
* correction
* more deletions :D
* tmp commit
* for VLMs too
* fix _check_outputs
* test nit
* make fixup
* fix another flaky
* test_generate_from_inputs_embeds -- handle missing attention mask
This commit is contained in:
@@ -28,7 +28,6 @@ from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -306,10 +305,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "THUDM/glm-4-9b"
|
||||
_torch_compile_test_revision = "refs/pr/15"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = GlmModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37)
|
||||
@@ -426,41 +421,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
    """Check that Flash Attention 2 generation matches the default attention path.

    Overrides the common test, which is flaky on tiny models, by using the
    full checkpoint instead. The same right-padded batch is generated from
    twice with sampling disabled -- once with the default attention
    implementation and once with `flash_attention_2` -- and the decoded
    texts must be identical.
    """
    checkpoint = "THUDM/glm-4-9b"
    revision = "refs/pr/15"
    # Loading arguments common to both model instances; only the attention
    # implementation differs between the two runs.
    load_kwargs = {"device_map": {"": 0}, "torch_dtype": torch.bfloat16, "revision": revision}
    generation_kwargs = {"max_new_tokens": 15, "do_sample": False}

    # Reference run with the default attention implementation.
    glm = GlmForCausalLM.from_pretrained(checkpoint, **load_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision)
    tokenizer.padding_side = "right"

    # Prompts of different lengths, so that `padding=True` pads the short one.
    prompts = ["hi", "Hello this is a very long sentence"]
    batch = tokenizer(prompts, return_tensors="pt", padding=True).to(0)

    reference_texts = tokenizer.batch_decode(glm.generate(**batch, **generation_kwargs))

    # Rebind the same name so the reference model is no longer referenced
    # once the Flash Attention 2 model has been loaded.
    glm = GlmForCausalLM.from_pretrained(checkpoint, attn_implementation="flash_attention_2", **load_kwargs)
    flash_texts = tokenizer.batch_decode(glm.generate(**batch, **generation_kwargs))

    self.assertListEqual(reference_texts, flash_texts)
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user