Mamba2 remove unecessary test parameterization (#38227)
This commit is contained in:
@@ -15,8 +15,6 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
Expectations,
|
Expectations,
|
||||||
@@ -362,14 +360,9 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
@parameterized.expand(
|
|
||||||
[
|
|
||||||
(torch_device,),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_simple_generate(self, device):
|
def test_simple_generate(self):
|
||||||
"""
|
"""
|
||||||
Simple generate test to avoid regressions.
|
Simple generate test to avoid regressions.
|
||||||
Note: state-spaces (cuda) implementation and pure torch implementation
|
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||||
@@ -380,9 +373,9 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
||||||
model.to(device)
|
model.to(torch_device)
|
||||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||||
device
|
torch_device
|
||||||
)
|
)
|
||||||
|
|
||||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||||
|
|||||||
Reference in New Issue
Block a user