@@ -363,7 +363,7 @@ class EvollaModelIntegrationTest(TestCasePlus):
|
||||
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf", revision="refs/pr/11")
|
||||
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf")
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@@ -382,16 +382,10 @@ class EvollaModelIntegrationTest(TestCasePlus):
|
||||
model = EvollaForProteinText2Text.from_pretrained(
|
||||
"westlake-repl/Evolla-10B-hf",
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto",
|
||||
device_map=torch_device,
|
||||
)
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# keep for debugging
|
||||
for i, t in enumerate(generated_text):
|
||||
t = bytes(t, "utf-8").decode("unicode_escape")
|
||||
print(f"{i}:\n{t}\n")
|
||||
|
||||
self.assertIn("This protein", generated_text[0])
|
||||
|
||||
self.assertIn("purine", generated_text[0])
|
||||
|
||||
@@ -201,6 +201,10 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="xLSTM cache is not iterable")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -260,13 +264,14 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
@require_torch
|
||||
@slow
|
||||
@require_read_token
|
||||
@unittest.skip("Model is fully broken currently")
|
||||
class xLSTMIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_id = "NX-AI/xLSTM-7b"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, legacy=False)
|
||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||
|
||||
def test_simple_generate(self, device):
|
||||
def test_simple_generate(self):
|
||||
"""
|
||||
Simple generate test to avoid regressions.
|
||||
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||
@@ -276,10 +281,9 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
tokenizer = self.tokenizer
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
||||
model.to(device)
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||
device
|
||||
torch_device
|
||||
)
|
||||
|
||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||
@@ -300,7 +304,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||
]
|
||||
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
# batched generation
|
||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
@@ -328,7 +332,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||
]
|
||||
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
# batched generation
|
||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
@@ -355,7 +359,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
torch.manual_seed(42)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
||||
with torch.no_grad():
|
||||
block = xLSTMBlock(config.to_xlstm_block_config(), layer_idx=0).to("cuda")
|
||||
block = xLSTMBlock(config.to_xlstm_block_config()).to("cuda")
|
||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
||||
|
||||
block.train()
|
||||
|
||||
Reference in New Issue
Block a user