Fix Evolla and xLSTM tests (#39769)

* fix all evolla

* xlstm
This commit is contained in:
Cyril Vallez
2025-07-30 09:51:55 +02:00
committed by GitHub
parent ec4033457e
commit 67cfe11528
5 changed files with 57 additions and 71 deletions

View File

@@ -363,7 +363,7 @@ class EvollaModelIntegrationTest(TestCasePlus):
@cached_property
def default_processor(self):
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf", revision="refs/pr/11")
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf")
@require_bitsandbytes
@slow
@@ -382,16 +382,10 @@ class EvollaModelIntegrationTest(TestCasePlus):
model = EvollaForProteinText2Text.from_pretrained(
"westlake-repl/Evolla-10B-hf",
quantization_config=quantization_config,
device_map="auto",
device_map=torch_device,
)
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
# keep for debugging
for i, t in enumerate(generated_text):
t = bytes(t, "utf-8").decode("unicode_escape")
print(f"{i}:\n{t}\n")
self.assertIn("This protein", generated_text[0])
self.assertIn("purine", generated_text[0])

View File

@@ -201,6 +201,10 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="xLSTM cache is not iterable")
def test_multi_gpu_data_parallel_forward(self):
pass
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -260,13 +264,14 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_torch
@slow
@require_read_token
@unittest.skip("Model is fully broken currently")
class xLSTMIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "NX-AI/xLSTM-7b"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, legacy=False)
self.prompt = ("[INST]Write a hello world program in C++.",)
def test_simple_generate(self, device):
def test_simple_generate(self):
"""
Simple generate test to avoid regressions.
Note: state-spaces (cuda) implementation and pure torch implementation
@@ -276,10 +281,9 @@ class xLSTMIntegrationTest(unittest.TestCase):
tokenizer = self.tokenizer
tokenizer.pad_token_id = tokenizer.eos_token_id
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
model.to(device)
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
device
torch_device
)
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
@@ -300,7 +304,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
# batched generation
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
@@ -328,7 +332,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
# batched generation
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
@@ -355,7 +359,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
torch.manual_seed(42)
with torch.amp.autocast(device_type="cuda", dtype=dtype):
with torch.no_grad():
block = xLSTMBlock(config.to_xlstm_block_config(), layer_idx=0).to("cuda")
block = xLSTMBlock(config.to_xlstm_block_config()).to("cuda")
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
block.train()