fix tests with main revision and read token (#33560)
* fix tests with main revision and read token * [run-slow]mamba2 * test previously skipped tests * [run-slow]mamba2 * skip some tests * [run-slow]mamba2 * finalize tests * [run-slow]mamba2
This commit is contained in:
@@ -20,7 +20,7 @@ from typing import Dict, List, Tuple
|
|||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -96,7 +96,7 @@ class Mamba2ModelTester:
|
|||||||
self.tie_word_embeddings = tie_word_embeddings
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
|
||||||
def get_large_model_config(self):
|
def get_large_model_config(self):
|
||||||
return Mamba2Config.from_pretrained("revision='refs/pr/9'")
|
return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
|
||||||
|
|
||||||
def prepare_config_and_inputs(
|
def prepare_config_and_inputs(
|
||||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||||
@@ -199,34 +199,26 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
def test_tied_weights_keys(self):
|
def test_tied_weights_keys(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
|
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
def test_generate_without_input_ids(self):
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
|
|
||||||
def test_beam_sample_generate(self):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Initialization of mamba2 fails this")
|
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||||
def test_multi_gpu_data_parallel_forward(self):
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
|
||||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_model_outputs_equivalence(self):
|
def test_model_outputs_equivalence(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -292,12 +284,11 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
|
@require_read_token
|
||||||
class Mamba2IntegrationTest(unittest.TestCase):
|
class Mamba2IntegrationTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
|
self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
|
||||||
self.model_id, revision="refs/pr/9", from_slow=True, legacy=False
|
|
||||||
)
|
|
||||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||||
|
|
||||||
@parameterized.expand(
|
@parameterized.expand(
|
||||||
@@ -317,7 +308,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = self.tokenizer
|
tokenizer = self.tokenizer
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
|
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||||
device
|
device
|
||||||
@@ -343,9 +334,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||||
]
|
]
|
||||||
|
|
||||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
|
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
# batched generation
|
# batched generation
|
||||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
@@ -375,9 +364,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||||
]
|
]
|
||||||
|
|
||||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
|
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
# batched generation
|
# batched generation
|
||||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user