From 4f0246e5359d2d9b7c9f1e5b7c3d7dc79fed0e99 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Thu, 19 Sep 2024 17:10:22 +0200 Subject: [PATCH] fix tests with main revision and read token (#33560) * fix tests with main revision and read token * [run-slow]mamba2 * test previously skipped tests * [run-slow]mamba2 * skip some tests * [run-slow]mamba2 * finalize tests * [run-slow]mamba2 --- tests/models/mamba2/test_modeling_mamba2.py | 37 +++++++-------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 276ecf2fd6..a1e2138d4d 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -20,7 +20,7 @@ from typing import Dict, List, Tuple from parameterized import parameterized from transformers import AutoTokenizer, Mamba2Config, is_torch_available -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -96,7 +96,7 @@ class Mamba2ModelTester: self.tie_word_embeddings = tie_word_embeddings def get_large_model_config(self): - return Mamba2Config.from_pretrained("revision='refs/pr/9'") + return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1") def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False @@ -199,34 +199,26 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def test_tied_weights_keys(self): pass - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_search_generate_dict_outputs_use_cache(self): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_sample_generate(self): + @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") + def test_generate_without_input_ids(self): pass @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_generate_without_input_ids(self): + def test_generate_from_inputs_embeds_decoder_only(self): pass @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") def test_greedy_generate_dict_outputs_use_cache(self): pass - @unittest.skip(reason="Initialization of mamba2 fails this") - def test_save_load_fast_init_from_base(self): + @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") + def test_beam_search_generate_dict_outputs_use_cache(self): pass @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_generate_from_inputs_embeds_decoder_only(self): - pass - def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -292,12 +284,11 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix @require_torch @slow +@require_read_token class Mamba2IntegrationTest(unittest.TestCase): def setUp(self): self.model_id = "mistralai/Mamba-Codestral-7B-v0.1" - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id, revision="refs/pr/9", from_slow=True, legacy=False - ) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False) self.prompt = ("[INST]Write a hello world program in C++.",) @parameterized.expand( @@ -317,7 +308,7 @@ class Mamba2IntegrationTest(unittest.TestCase): tokenizer = self.tokenizer tokenizer.pad_token_id = tokenizer.eos_token_id - model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16) + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16) model.to(device) input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to( device @@ -343,9 +334,7 @@ class Mamba2IntegrationTest(unittest.TestCase): "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", ] - model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to( - torch_device - ) + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) tokenizer.pad_token_id = tokenizer.eos_token_id # batched generation tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device) @@ -375,9 +364,7 @@ class Mamba2IntegrationTest(unittest.TestCase): "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", ] - model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to( - torch_device - ) + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) tokenizer.pad_token_id = tokenizer.eos_token_id # batched generation tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)