@@ -1442,7 +1442,6 @@ class EvollaDecoderLayer(GradientCheckpointingLayer):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
protein_kv_states: Optional[torch.Tensor] = None,
|
||||
@@ -1497,7 +1496,11 @@ class EvollaPreTrainedModel(PreTrainedModel):
|
||||
config: EvollaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["EvollaDecoderLayer"]
|
||||
_no_split_modules = [
|
||||
"EvollaDecoderLayer",
|
||||
"EvollaSequenceCompressorResampler",
|
||||
"EvollaSequenceAlignerCrossAttention",
|
||||
]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
@@ -1512,20 +1515,8 @@ class EvollaPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, EvollaRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, EvollaSequenceAlignerCrossAttention):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, EvollaSequenceAlignerCrossAttention):
|
||||
module.gate_attention.zero_()
|
||||
module.gate_ffw.zero_()
|
||||
module.attention_norm.weight.data.fill_(1.0)
|
||||
@@ -1594,15 +1585,6 @@ class EvollaModel(EvollaPreTrainedModel):
|
||||
msa_batch_mask (torch.Tensor):
|
||||
The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
|
||||
"""
|
||||
# If not provided `protein_feats`, use the `protein_encoder` to get the protein features
|
||||
if protein_input_ids is not None and protein_attention_mask is not None:
|
||||
protein_outputs = self.protein_encoder(
|
||||
input_ids=protein_input_ids,
|
||||
attention_mask=protein_attention_mask,
|
||||
)
|
||||
protein_feats = protein_outputs.sequence_compressor_output
|
||||
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
@@ -1621,6 +1603,17 @@ class EvollaModel(EvollaPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
protein_feats = None
|
||||
protein_batch_mask = None
|
||||
# If provided, actually compute them
|
||||
if protein_input_ids is not None and protein_attention_mask is not None:
|
||||
protein_outputs = self.protein_encoder(
|
||||
input_ids=protein_input_ids,
|
||||
attention_mask=protein_attention_mask,
|
||||
)
|
||||
protein_feats = protein_outputs.sequence_compressor_output
|
||||
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
|
||||
@@ -717,7 +717,6 @@ class EvollaDecoderLayer(LlamaDecoderLayer):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
protein_kv_states: Optional[torch.Tensor] = None,
|
||||
@@ -769,23 +768,16 @@ class EvollaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
class EvollaPreTrainedModel(LlamaPreTrainedModel):
|
||||
_supports_attention_backend = False
|
||||
_no_split_modules = [
|
||||
"EvollaDecoderLayer",
|
||||
"EvollaSequenceCompressorResampler",
|
||||
"EvollaSequenceAlignerCrossAttention",
|
||||
]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, EvollaRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, EvollaSequenceAlignerCrossAttention):
|
||||
LlamaPreTrainedModel._init_weights(module)
|
||||
if isinstance(module, EvollaSequenceAlignerCrossAttention):
|
||||
module.gate_attention.zero_()
|
||||
module.gate_ffw.zero_()
|
||||
module.attention_norm.weight.data.fill_(1.0)
|
||||
@@ -854,15 +846,6 @@ class EvollaModel(EvollaPreTrainedModel):
|
||||
msa_batch_mask (torch.Tensor):
|
||||
The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
|
||||
"""
|
||||
# If not provided `protein_feats`, use the `protein_encoder` to get the protein features
|
||||
if protein_input_ids is not None and protein_attention_mask is not None:
|
||||
protein_outputs = self.protein_encoder(
|
||||
input_ids=protein_input_ids,
|
||||
attention_mask=protein_attention_mask,
|
||||
)
|
||||
protein_feats = protein_outputs.sequence_compressor_output
|
||||
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
@@ -881,6 +864,17 @@ class EvollaModel(EvollaPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
protein_feats = None
|
||||
protein_batch_mask = None
|
||||
# If provided, actually compute them
|
||||
if protein_input_ids is not None and protein_attention_mask is not None:
|
||||
protein_outputs = self.protein_encoder(
|
||||
input_ids=protein_input_ids,
|
||||
attention_mask=protein_attention_mask,
|
||||
)
|
||||
protein_feats = protein_outputs.sequence_compressor_output
|
||||
protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
|
||||
@@ -1037,17 +1037,17 @@ else:
|
||||
self.qk_dim = int(config.hidden_size * config.qk_dim_factor)
|
||||
|
||||
if self.config.weight_mode == "single":
|
||||
self.query = nn.Linear(
|
||||
self.q = nn.Linear(
|
||||
in_features=self.config.hidden_size,
|
||||
out_features=self.qk_dim,
|
||||
bias=self.config.use_bias,
|
||||
)
|
||||
self.key = nn.Linear(
|
||||
self.k = nn.Linear(
|
||||
in_features=self.config.hidden_size,
|
||||
out_features=self.qk_dim,
|
||||
bias=self.config.use_bias,
|
||||
)
|
||||
self.value = nn.Linear(
|
||||
self.v = nn.Linear(
|
||||
in_features=self.config.hidden_size,
|
||||
out_features=self.v_dim,
|
||||
bias=self.config.use_bias,
|
||||
@@ -1104,9 +1104,9 @@ else:
|
||||
raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
|
||||
batch_size, sequence_length, _ = x.shape
|
||||
if self.config.weight_mode == "single":
|
||||
query = self.query(x)
|
||||
key = self.key(x)
|
||||
value = self.value(x)
|
||||
query = self.q(x)
|
||||
key = self.k(x)
|
||||
value = self.v(x)
|
||||
o_preact = self.ogate_preact(x)
|
||||
i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
|
||||
f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)
|
||||
@@ -1535,6 +1535,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None, # not used but needed, otherwise generate complains when passing tokenizer inputs
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
cache_params: Optional[xLSTMCache] = None,
|
||||
|
||||
@@ -363,7 +363,7 @@ class EvollaModelIntegrationTest(TestCasePlus):
|
||||
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf", revision="refs/pr/11")
|
||||
return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf")
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@@ -382,16 +382,10 @@ class EvollaModelIntegrationTest(TestCasePlus):
|
||||
model = EvollaForProteinText2Text.from_pretrained(
|
||||
"westlake-repl/Evolla-10B-hf",
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto",
|
||||
device_map=torch_device,
|
||||
)
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# keep for debugging
|
||||
for i, t in enumerate(generated_text):
|
||||
t = bytes(t, "utf-8").decode("unicode_escape")
|
||||
print(f"{i}:\n{t}\n")
|
||||
|
||||
self.assertIn("This protein", generated_text[0])
|
||||
|
||||
self.assertIn("purine", generated_text[0])
|
||||
|
||||
@@ -201,6 +201,10 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="xLSTM cache is not iterable")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -260,13 +264,14 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
@require_torch
|
||||
@slow
|
||||
@require_read_token
|
||||
@unittest.skip("Model is fully broken currently")
|
||||
class xLSTMIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_id = "NX-AI/xLSTM-7b"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, legacy=False)
|
||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||
|
||||
def test_simple_generate(self, device):
|
||||
def test_simple_generate(self):
|
||||
"""
|
||||
Simple generate test to avoid regressions.
|
||||
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||
@@ -276,10 +281,9 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
tokenizer = self.tokenizer
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
||||
model.to(device)
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||
device
|
||||
torch_device
|
||||
)
|
||||
|
||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||
@@ -300,7 +304,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||
]
|
||||
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
# batched generation
|
||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
@@ -328,7 +332,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||
]
|
||||
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16, device_map=torch_device)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
# batched generation
|
||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
@@ -355,7 +359,7 @@ class xLSTMIntegrationTest(unittest.TestCase):
|
||||
torch.manual_seed(42)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
||||
with torch.no_grad():
|
||||
block = xLSTMBlock(config.to_xlstm_block_config(), layer_idx=0).to("cuda")
|
||||
block = xLSTMBlock(config.to_xlstm_block_config()).to("cuda")
|
||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
||||
|
||||
block.train()
|
||||
|
||||
Reference in New Issue
Block a user