From 3f0b7d0fac38bf326db672a9f246eaa247ff266d Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 20 May 2025 15:54:04 +0200 Subject: [PATCH] Mamba2 remove unecessary test parameterization (#38227) --- tests/models/mamba2/test_modeling_mamba2.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index ee63e825e1..5777053923 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -15,8 +15,6 @@ import unittest -from parameterized import parameterized - from transformers import AutoTokenizer, Mamba2Config, is_torch_available from transformers.testing_utils import ( Expectations, @@ -362,14 +360,9 @@ class Mamba2IntegrationTest(unittest.TestCase): self.prompt = ("[INST]Write a hello world program in C++.",) @require_read_token - @parameterized.expand( - [ - (torch_device,), - ] - ) @slow @require_torch - def test_simple_generate(self, device): + def test_simple_generate(self): """ Simple generate test to avoid regressions. Note: state-spaces (cuda) implementation and pure torch implementation @@ -380,9 +373,9 @@ class Mamba2IntegrationTest(unittest.TestCase): tokenizer.pad_token_id = tokenizer.eos_token_id model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16) - model.to(device) + model.to(torch_device) input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to( - device + torch_device ) out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)