@@ -1442,7 +1442,6 @@ class EvollaDecoderLayer(GradientCheckpointingLayer):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
protein_kv_states: Optional[torch.Tensor] = None,
|
protein_kv_states: Optional[torch.Tensor] = None,
|
||||||
@@ -1497,7 +1496,11 @@ class EvollaPreTrainedModel(PreTrainedModel):
|
|||||||
config: EvollaConfig
|
config: EvollaConfig
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["EvollaDecoderLayer"]
|
_no_split_modules = [
|
||||||
|
"EvollaDecoderLayer",
|
||||||
|
"EvollaSequenceCompressorResampler",
|
||||||
|
"EvollaSequenceAlignerCrossAttention",
|
||||||
|
]
|
||||||
_skip_keys_device_placement = ["past_key_values"]
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
_supports_flash_attn = True
|
_supports_flash_attn = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
@@ -1512,20 +1515,8 @@ class EvollaPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
if isinstance(module, nn.Linear):
|
super()._init_weights(module)
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
if isinstance(module, EvollaSequenceAlignerCrossAttention):
|
||||||
if module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if module.padding_idx is not None:
|
|
||||||
module.weight.data[module.padding_idx].zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
elif isinstance(module, EvollaRMSNorm):
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
elif isinstance(module, EvollaSequenceAlignerCrossAttention):
|
|
||||||
module.gate_attention.zero_()
|
module.gate_attention.zero_()
|
||||||
module.gate_ffw.zero_()
|
module.gate_ffw.zero_()
|
||||||
module.attention_norm.weight.data.fill_(1.0)
|
module.attention_norm.weight.data.fill_(1.0)
|
||||||
@@ -1594,15 +1585,6 @@ class EvollaModel(EvollaPreTrainedModel):
|
|||||||
msa_batch_mask (torch.Tensor):
|
msa_batch_mask (torch.Tensor):
|
||||||
The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
|
The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
|
||||||
"""
|
"""
|
||||||
# If not provided `protein_feats`, use the `protein_encoder` to get the protein features
|
|
||||||
if protein_input_ids is not None and protein_attention_mask is not None:
|
|
||||||
protein_outputs = self.protein_encoder(
|
|
||||||
input_ids=protein_input_ids,
|
|
||||||
attention_mask=protein_attention_mask,
|
|
||||||
)
|
|
||||||
protein_feats = protein_outputs.sequence_compressor_output
|
|
||||||
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
@@ -1621,6 +1603,17 @@ class EvollaModel(EvollaPreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
protein_feats = None
|
||||||
|
protein_batch_mask = None
|
||||||
|
# If provided, actually compute them
|
||||||
|
if protein_input_ids is not None and protein_attention_mask is not None:
|
||||||
|
protein_outputs = self.protein_encoder(
|
||||||
|
input_ids=protein_input_ids,
|
||||||
|
attention_mask=protein_attention_mask,
|
||||||
|
)
|
||||||
|
protein_feats = protein_outputs.sequence_compressor_output
|
||||||
|
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
||||||
|
|
||||||
causal_mask = create_causal_mask(
|
causal_mask = create_causal_mask(
|
||||||
config=self.config,
|
config=self.config,
|
||||||
input_embeds=inputs_embeds,
|
input_embeds=inputs_embeds,
|
||||||
|
|||||||
@@ -717,7 +717,6 @@ class EvollaDecoderLayer(LlamaDecoderLayer):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
protein_kv_states: Optional[torch.Tensor] = None,
|
protein_kv_states: Optional[torch.Tensor] = None,
|
||||||
@@ -769,23 +768,16 @@ class EvollaDecoderLayer(LlamaDecoderLayer):
|
|||||||
|
|
||||||
class EvollaPreTrainedModel(LlamaPreTrainedModel):
|
class EvollaPreTrainedModel(LlamaPreTrainedModel):
|
||||||
_supports_attention_backend = False
|
_supports_attention_backend = False
|
||||||
|
_no_split_modules = [
|
||||||
|
"EvollaDecoderLayer",
|
||||||
|
"EvollaSequenceCompressorResampler",
|
||||||
|
"EvollaSequenceAlignerCrossAttention",
|
||||||
|
]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
if isinstance(module, nn.Linear):
|
LlamaPreTrainedModel._init_weights(module)
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
if isinstance(module, EvollaSequenceAlignerCrossAttention):
|
||||||
if module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if module.padding_idx is not None:
|
|
||||||
module.weight.data[module.padding_idx].zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
elif isinstance(module, EvollaRMSNorm):
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
elif isinstance(module, EvollaSequenceAlignerCrossAttention):
|
|
||||||
module.gate_attention.zero_()
|
module.gate_attention.zero_()
|
||||||
module.gate_ffw.zero_()
|
module.gate_ffw.zero_()
|
||||||
module.attention_norm.weight.data.fill_(1.0)
|
module.attention_norm.weight.data.fill_(1.0)
|
||||||
@@ -854,15 +846,6 @@ class EvollaModel(EvollaPreTrainedModel):
|
|||||||
msa_batch_mask (torch.Tensor):
|
msa_batch_mask (torch.Tensor):
|
||||||
The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
|
The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
|
||||||
"""
|
"""
|
||||||
# If not provided `protein_feats`, use the `protein_encoder` to get the protein features
|
|
||||||
if protein_input_ids is not None and protein_attention_mask is not None:
|
|
||||||
protein_outputs = self.protein_encoder(
|
|
||||||
input_ids=protein_input_ids,
|
|
||||||
attention_mask=protein_attention_mask,
|
|
||||||
)
|
|
||||||
protein_feats = protein_outputs.sequence_compressor_output
|
|
||||||
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
@@ -881,6 +864,17 @@ class EvollaModel(EvollaPreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
protein_feats = None
|
||||||
|
protein_batch_mask = None
|
||||||
|
# If provided, actually compute them
|
||||||
|
if protein_input_ids is not None and protein_attention_mask is not None:
|
||||||
|
protein_outputs = self.protein_encoder(
|
||||||
|
input_ids=protein_input_ids,
|
||||||
|
attention_mask=protein_attention_mask,
|
||||||
|
)
|
||||||
|
protein_feats = protein_outputs.sequence_compressor_output
|
||||||
|
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
||||||
|
|
||||||
causal_mask = create_causal_mask(
|
causal_mask = create_causal_mask(
|
||||||
config=self.config,
|
config=self.config,
|
||||||
input_embeds=inputs_embeds,
|
input_embeds=inputs_embeds,
|
||||||
|
|||||||
@@ -1037,17 +1037,17 @@ else:
|
|||||||
self.qk_dim = int(config.hidden_size * config.qk_dim_factor)
|
self.qk_dim = int(config.hidden_size * config.qk_dim_factor)
|
||||||
|
|
||||||
if self.config.weight_mode == "single":
|
if self.config.weight_mode == "single":
|
||||||
self.query = nn.Linear(
|
self.q = nn.Linear(
|
||||||
in_features=self.config.hidden_size,
|
in_features=self.config.hidden_size,
|
||||||
out_features=self.qk_dim,
|
out_features=self.qk_dim,
|
||||||
bias=self.config.use_bias,
|
bias=self.config.use_bias,
|
||||||
)
|
)
|
||||||
self.key = nn.Linear(
|
self.k = nn.Linear(
|
||||||
in_features=self.config.hidden_size,
|
in_features=self.config.hidden_size,
|
||||||
out_features=self.qk_dim,
|
out_features=self.qk_dim,
|
||||||
bias=self.config.use_bias,
|
bias=self.config.use_bias,
|
||||||
)
|
)
|
||||||
self.value = nn.Linear(
|
self.v = nn.Linear(
|
||||||
in_features=self.config.hidden_size,
|
in_features=self.config.hidden_size,
|
||||||
out_features=self.v_dim,
|
out_features=self.v_dim,
|
||||||
bias=self.config.use_bias,
|
bias=self.config.use_bias,
|
||||||
@@ -1104,9 +1104,9 @@ else:
|
|||||||
raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
|
raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
|
||||||
batch_size, sequence_length, _ = x.shape
|
batch_size, sequence_length, _ = x.shape
|
||||||
if self.config.weight_mode == "single":
|
if self.config.weight_mode == "single":
|
||||||
query = self.query(x)
|
query = self.q(x)
|
||||||
key = self.key(x)
|
key = self.k(x)
|
||||||
value = self.value(x)
|
value = self.v(x)
|
||||||
o_preact = self.ogate_preact(x)
|
o_preact = self.ogate_preact(x)
|
||||||
i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
|
i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
|
||||||
f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)
|
f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)
|
||||||
@@ -1535,6 +1535,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
|
|||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask=None, # not used but needed, otherwise generate complains when passing tokenizer inputs
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
cache_params: Optional[xLSTMCache] = None,
|
cache_params: Optional[xLSTMCache] = None,
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ class EvollaModelIntegrationTest(TestCasePlus):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def default_processor(self):
|
def default_processor(self):
|
||||||
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf", revision="refs/pr/11")
|
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf")
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@slow
|
@slow
|
||||||
@@ -382,16 +382,10 @@ class EvollaModelIntegrationTest(TestCasePlus):
|
|||||||
model = EvollaForProteinText2Text.from_pretrained(
|
model = EvollaForProteinText2Text.from_pretrained(
|
||||||
"westlake-repl/Evolla-10B-hf",
|
"westlake-repl/Evolla-10B-hf",
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
device_map="auto",
|
device_map=torch_device,
|
||||||
)
|
)
|
||||||
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
# keep for debugging
|
|
||||||
for i, t in enumerate(generated_text):
|
|
||||||
t = bytes(t, "utf-8").decode("unicode_escape")
|
|
||||||
print(f"{i}:\n{t}\n")
|
|
||||||
|
|
||||||
self.assertIn("This protein", generated_text[0])
|
self.assertIn("This protein", generated_text[0])
|
||||||
|
|
||||||
self.assertIn("purine", generated_text[0])
|
self.assertIn("purine", generated_text[0])
|
||||||
|
|||||||
@@ -201,6 +201,10 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="xLSTM cache is not iterable")
|
||||||
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_model_outputs_equivalence(self):
|
def test_model_outputs_equivalence(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -260,13 +264,14 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
@require_read_token
|
@require_read_token
|
||||||
|
@unittest.skip("Model is fully broken currently")
|
||||||
class xLSTMIntegrationTest(unittest.TestCase):
|
class xLSTMIntegrationTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_id = "NX-AI/xLSTM-7b"
|
self.model_id = "NX-AI/xLSTM-7b"
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, legacy=False)
|
||||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||||
|
|
||||||
def test_simple_generate(self, device):
|
def test_simple_generate(self):
|
||||||
"""
|
"""
|
||||||
Simple generate test to avoid regressions.
|
Simple generate test to avoid regressions.
|
||||||
Note: state-spaces (cuda) implementation and pure torch implementation
|
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||||
@@ -276,10 +281,9 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = self.tokenizer
|
tokenizer = self.tokenizer
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||||
model.to(device)
|
|
||||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||||
device
|
torch_device
|
||||||
)
|
)
|
||||||
|
|
||||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||||
@@ -300,7 +304,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
|||||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||||
]
|
]
|
||||||
|
|
||||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
# batched generation
|
# batched generation
|
||||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
@@ -328,7 +332,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
|||||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||||
]
|
]
|
||||||
|
|
||||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
# batched generation
|
# batched generation
|
||||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
@@ -355,7 +359,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
|||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
block = xLSTMBlock(config.to_xlstm_block_config(), layer_idx=0).to("cuda")
|
block = xLSTMBlock(config.to_xlstm_block_config()).to("cuda")
|
||||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
||||||
|
|
||||||
block.train()
|
block.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user