From 45c7bfb1571160d2c06b880073a5c73e6bfa3677 Mon Sep 17 00:00:00 2001 From: Xibin Bayes Zhou Date: Sat, 26 Jul 2025 01:11:57 +0800 Subject: [PATCH] Add evolla rebase main (#36232) * add evolla * adding protein encoder part * add initial processing test * save processor * add docstring * add evolla processor * add two test * change vision to protein * change resampler to sequence_compressor * change vision to protein * initial update for llama * add initial update for llamaForCausalLM * add `test_processor`, `test_saprot_output`, `test_protein_encoder_output` * change evolla, but still working on it * add test_single_forward * pass test_attention_outputs * pass test_hidden_states_output * pass test_save_load and test_from_pretrained_no_checkpoint * pass test_cpu_offload * skip some tests * update new progress * skip test_model_is_small * pass test_model_weights_reload_no_missing_tied_weights * pass test_model_get_set_embeddings * pass test_cpu_offload * skip test_resize_embeddings * add pipeline_model_mapping * remote old setUp * pass processor save_pretrained and load_pretrained * remove pooling layer * pass test_inputs_embeds_matches_input_ids * pass test_model_is_small * pass test_attention_outputs * pass test_initialization * pass test_model_get_set_embeddings * pass test_single_forward * skip test_disk_offload_bin and test_disk_offload_safetensors * fix most tests * pass test_protein_encoder_output * remove useless code * add EvollaForProteinText2Text * pass test_saprot_output * pass all EvollaModelTest test and remove processor test * add processor test to its own file * skip is_training since esm skipped it and the saprot code causes error when setting is_training True * pass processor tests * solve all except config * pass most cases * change init * add doc to `configuration_evolla.py` * remove image_processing test * remove extra processor test * remove extra modules * remove extra modules * change all configs into one config * pass all evolla test * pass `make fixup` * update short summary * update Evolla-10B-hf * pass check_dummies.py and check_code_quality * fix `tests/models/auto/test_tokenization_auto.py::AutoTokenizerTest::test_model_name_edge_cases_in_mappings` * remove dummy codes * change format * fix llava issue * update format * update to solve llama3 access issue * update to make forward right * solve processor save load problem from instructblip solution * remove unexpected file * skip `test_generation_tester_mixin_inheritance` * add `test_single_forward_correct` and `test_inference_natural_language_protein_reasoning` * add `modular_evolla.py` * solved issue #36362 * run `make fixup` * update modular * solve float32 training * add fix * solve `utils/check_docstrings.py` * update * update * update * remove other files and replace sequential and einsum * add use case in document * update the models * update model * change some wrong code * Update src/transformers/models/evolla/modular_evolla.py Co-authored-by: Cyril Vallez * Update src/transformers/models/evolla/modular_evolla.py Co-authored-by: Cyril Vallez * Update src/transformers/models/evolla/modular_evolla.py Co-authored-by: Cyril Vallez * Update src/transformers/models/evolla/modular_evolla.py Co-authored-by: Cyril Vallez * fix issues mentioned in PR * update style and rearrange the placement * fix return_dict argument issue * solve SaProtConfig issue * Solve EvollaSaProtRotaryEmbedding issue * solve attention_mask issue * solve almosst all issues * make style * update config * remove unrelated pickle file * delete pickle files * fix config * simplify a lot * remove past k-v from encoder * continue work * style * skip it from init * fix init * fix init * simplify more * fill in docstrings * change test for generation * skip test * fix style --------- Co-authored-by: Chenchen Han <13980209828@163.com> Co-authored-by: Cyril Vallez Co-authored-by: Cyril Vallez --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/evolla.md | 95 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/processing_auto.py | 1 + src/transformers/models/evolla/__init__.py | 28 + .../models/evolla/configuration_evolla.py | 279 +++ .../models/evolla/modeling_evolla.py | 1761 +++++++++++++++++ .../models/evolla/modular_evolla.py | 1008 ++++++++++ .../models/evolla/processing_evolla.py | 247 +++ tests/models/evolla/__init__.py | 0 tests/models/evolla/test_modeling_evolla.py | 397 ++++ tests/models/evolla/test_processor_evolla.py | 295 +++ utils/check_repo.py | 1 + 15 files changed, 4120 insertions(+) create mode 100644 docs/source/en/model_doc/evolla.md create mode 100644 src/transformers/models/evolla/__init__.py create mode 100644 src/transformers/models/evolla/configuration_evolla.py create mode 100644 src/transformers/models/evolla/modeling_evolla.py create mode 100644 src/transformers/models/evolla/modular_evolla.py create mode 100644 src/transformers/models/evolla/processing_evolla.py create mode 100644 tests/models/evolla/__init__.py create mode 100644 tests/models/evolla/test_modeling_evolla.py create mode 100644 tests/models/evolla/test_processor_evolla.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c7fc602065..e317998a36 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -975,6 +975,8 @@ title: Donut - local: model_doc/emu3 title: Emu3 + - local: model_doc/evolla + title: Evolla - local: model_doc/flava title: FLAVA - local: model_doc/gemma3 diff --git a/docs/source/en/model_doc/evolla.md b/docs/source/en/model_doc/evolla.md new file mode 100644 index 0000000000..79c3b120cb --- /dev/null +++ b/docs/source/en/model_doc/evolla.md @@ -0,0 +1,95 @@ + + +# Evolla + +## Overview + +The Evolla model was proposed in [Decoding the Molecular Language of Proteins with Evolla](https://doi.org/10.1101/2025.01.05.630192) by [Zhou et al.](https://doi.org/10.1101/2025.01.05.630192). + +Evolla is an advanced 80-billion-parameter protein-language generative model designed to decode the molecular language of proteins. It integrates information from protein sequences, structures, and user queries to generate precise and contextually nuanced insights into protein function. Trained on an unprecedented AI-generated dataset of 546 million protein question-answer pairs and 150 billion word tokens, Evolla significantly advances research in proteomics and functional genomics, providing expert-level insights and shedding light on the molecular logic encoded in proteins. + +The abstract from the paper is the following: + +*Proteins, nature’s intricate molecular machines, are the products of billions of years of evolution and play fundamental roles in sustaining life. Yet, deciphering their molecular language - that is, understanding how protein sequences and structures encode and determine biological functions - remains a corner-stone challenge in modern biology. Here, we introduce Evolla, an 80 billion frontier protein-language generative model designed to decode the molecular language of proteins. By integrating information from protein sequences, structures, and user queries, Evolla generates precise and contextually nuanced insights into protein function. A key innovation of Evolla lies in its training on an unprecedented AI-generated dataset: 546 million protein question-answer pairs and 150 billion word tokens, designed to reflect the immense complexity and functional diversity of proteins. Post-pretraining, Evolla integrates Direct Preference Optimization (DPO) to refine the model based on preference signals and Retrieval-Augmented Generation (RAG) for external knowledge incorporation, improving response quality and relevance. To evaluate its performance, we propose a novel framework, Instructional Response Space (IRS), demonstrating that Evolla delivers expert-level insights, advancing research in proteomics and functional genomics while shedding light on the molecular logic encoded in proteins. The online demo is available at http://www.chat-protein.com/.* + +Examples: + +```python +processor = EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-DPO-hf") +model = EvollaForProteinText2Text.from_pretrained("westlake-repl/Evolla-10B-DPO-hf") +# aa_seq should have same length as foldseek +protein_inputs = [ + { + + "aa_seq": "MATGGRRG...", + "foldseek": "###lqpfd...", # hashtag means the low-confidence foldseek tokens + }, + { + "aa_seq": "MLPGLALL...", + "foldseek": "dfwwkwad...", + } +] +message_list = [ + [ + { + "role": "system", + "content": "You are an AI expert that can answer any questions about protein.", + }, + {"role": "user", "content": "What is the function of this protein?"}, + ], + [ + { + "role": "system", + "content": "You are an AI expert that can answer any questions about protein.", + }, + {"role": "user", "content": "What is the function of this protein?"}, + ] +] +input_dict = processor( + protein_informations, messages_list, return_tensors="pt", text_max_length=512, protein_max_length=1024 +) +with torch.no_grad(): + generated_ids = hf_model.generate(**input_dict) +generated_texts = processor.batch_decode( + generated_ids, skip_special_tokens=True +) +``` + +Tips: + +- This model was contributed by [Xibin Bayes Zhou](https://huggingface.co/XibinBayesZhou). +- The original code can be found [here](https://github.com/westlake-repl/Evolla). + + +## EvollaConfig + +[[autodoc]] EvollaConfig + +## EvollaModel + +[[autodoc]] EvollaModel + - forward + +## EvollaForProteinText2Text + +[[autodoc]] EvollaForProteinText2Text + - forward + +## EvollaProcessor + +[[autodoc]] EvollaProcessor + - __call__ diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 738d2ab83c..7b59f958f0 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -110,6 +110,7 @@ if TYPE_CHECKING: from .encoder_decoder import * from .ernie import * from .esm import * + from .evolla import * from .falcon import * from .falcon_h1 import * from .falcon_mamba import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 75a328d3cc..4d22bd00ef 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -133,6 +133,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("ernie4_5_moe", "Ernie4_5_MoeConfig"), ("ernie_m", "ErnieMConfig"), ("esm", "EsmConfig"), + ("evolla", "EvollaConfig"), ("falcon", "FalconConfig"), ("falcon_h1", "FalconH1Config"), ("falcon_mamba", "FalconMambaConfig"), @@ -528,6 +529,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("ernie4_5_moe", "Ernie4_5_MoE"), ("ernie_m", "ErnieM"), ("esm", "ESM"), + ("evolla", "Evolla"), ("falcon", "Falcon"), ("falcon3", "Falcon3"), ("falcon_h1", "FalconH1"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0574346832..9d6622f389 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -124,6 +124,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("ernie4_5_moe", "Ernie4_5_MoeModel"), ("ernie_m", "ErnieMModel"), ("esm", "EsmModel"), + ("evolla", "EvollaModel"), ("falcon", "FalconModel"), ("falcon_h1", "FalconH1Model"), ("falcon_mamba", "FalconMambaModel"), @@ -402,6 +403,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("distilbert", "DistilBertForMaskedLM"), ("electra", "ElectraForPreTraining"), ("ernie", "ErnieForPreTraining"), + ("evolla", "EvollaForProteinText2Text"), ("falcon_mamba", "FalconMambaForCausalLM"), ("flaubert", "FlaubertWithLMHeadModel"), ("flava", "FlavaForPreTraining"), @@ -934,6 +936,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( ("blip-2", "Blip2ForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"), ("emu3", "Emu3ForConditionalGeneration"), + ("evolla", "EvollaForProteinText2Text"), ("fuyu", "FuyuForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3n", "Gemma3nForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 69f65e23e1..31b798c805 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -64,6 +64,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("colqwen2", "ColQwen2Processor"), ("dia", "DiaProcessor"), ("emu3", "Emu3Processor"), + ("evolla", "EvollaProcessor"), ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), ("gemma3", "Gemma3Processor"), diff --git a/src/transformers/models/evolla/__init__.py b/src/transformers/models/evolla/__init__.py new file mode 100644 index 0000000000..09be74f033 --- /dev/null +++ b/src/transformers/models/evolla/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_evolla import * + from .modeling_evolla import * + from .processing_evolla import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/evolla/configuration_evolla.py b/src/transformers/models/evolla/configuration_evolla.py new file mode 100644 index 0000000000..18e12150f1 --- /dev/null +++ b/src/transformers/models/evolla/configuration_evolla.py @@ -0,0 +1,279 @@ +# coding=utf-8 +# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evolla model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SaProtConfig(PretrainedConfig): + r"""This is the configuration class to store the configuration of a [`EvollaSaProtProteinEncoder`]. It is used to instantiate a + SaProt model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 446): + Vocabulary size of the protein sequence model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`EvollaModel`]. + mask_token_id (`int`, *optional*, defaults to 4): + The id of the *mask* token in the protein sequence model. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the *padding* token in the protein sequence model. + hidden_size (`int`, *optional*, defaults to 1280): + Dimensionality of the protein sequence model layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 33): + Number of hidden layers in the protein sequence model. + num_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the protein sequence model. + intermediate_size (`int`, *optional*, defaults to 5120): + Dimensionality of the intermediate layers in the protein sequence model. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the hidden layers in the protein sequence model. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities in the protein sequence model. + max_position_embeddings (`int`, *optional*, defaults to 1026): + The maximum sequence length that the protein sequence model might ever be used with. Typically set this to + something large just in case (e.g., 512 or 1024 or 2048). + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for the layer normalization layer in the protein sequence model. + position_embedding_type (`str`, *optional*, defaults to `"rotary"`): + The type of position embedding to use in the protein sequence model. Currently only `"rotary"` is supported. + emb_layer_norm_before (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization before the position embedding in the protein sequence model. + token_dropout (`bool`, *optional*, defaults to `True`): + Whether to apply dropout to the tokens in the protein sequence model.""" + + def __init__( + self, + vocab_size=446, + mask_token_id=4, + pad_token_id=1, + hidden_size=1280, + num_hidden_layers=33, + num_attention_heads=20, + intermediate_size=5120, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1026, + initializer_range=0.02, + layer_norm_eps=1e-05, + position_embedding_type="rotary", + use_cache=True, + emb_layer_norm_before=False, + token_dropout=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.emb_layer_norm_before = emb_layer_norm_before + self.token_dropout = token_dropout + + +class EvollaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EvollaModel`]. It is used to instantiate an + Evolla model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Evolla-10B. + + e.g. [westlake-repl/Evolla-10B-hf](https://huggingface.co/westlake-repl/Evolla-10B-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + protein_encoder_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SaProtConfig`]. + vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the Evolla llama model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`EvollaModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the llama layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimensionality of the intermediate layers in the llama model. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the llama model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the llama model. + num_key_value_heads (`int`, *optional*, defaults to 8): + Number of key-value pairs for each attention layer in the llama model. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the llama model. If string, `"gelu"`, `"relu"`, + `"selu"` and `"silu"` are supported. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for the RMS-norm layer in the llama model. + rope_theta (`float`, *optional*, defaults to 500000.0): + The threshold value for the RoPE layer in the llama model. + rope_scaling (`float`, *optional*): + The scaling factor for the RoPE layer in the llama model. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the attention layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention layer. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layer. + aligner_ffn_mult (`int`, *optional*, defaults to 4): + The FFN multiplier for the aligner layer. + aligner_enable_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the aligner layer. + aligner_attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities in the aligner layer. + aligner_num_add_layers (`int`, *optional*, defaults to 8): + The number of additional layers for the aligner layer. + resampler_depth (`int`, *optional*, defaults to 6): + The depth of the resampler layer in the llama model. + resampler_dim_head (`int`, *optional*, defaults to 64): + The dimension of the heads in the resampler layer in the llama model. + resampler_heads (`int`, *optional*, defaults to 8): + The number of heads in the resampler layer in the llama model. + resampler_num_latents (`int`, *optional*, defaults to 64): + The number of latents in the resampler layer in the llama model. + resampler_ff_mult (`int`, *optional*, defaults to 4): + The FFN multiplier for the resampler layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + bos_token_id (`int`, *optional*, defaults to 128000): + The id of the *beginning-of-sequence* token. + eos_token_id (`int`, *optional*, defaults to 128009): + The id of the *end-of-sequence* token. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the input and output word embeddings. + + Example: + + ```python + >>> from transformers import EvollaModel, EvollaConfig + + >>> # Initializing a Evolla evolla-10b style configuration + >>> configuration = EvollaConfig() + + >>> # Initializing a model from the evolla-10b style configuration + >>> model = EvollaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "EvollaModel" + sub_configs = {"protein_encoder_config": SaProtConfig} + + def __init__( + self, + protein_encoder_config=None, + vocab_size=128256, # llama vocab size + hidden_size=4096, # llama hidden size + intermediate_size=14336, # llama intermediate size + num_hidden_layers=32, # llama num layers + num_attention_heads=32, # llama num heads + num_key_value_heads=8, # llama num key-value heads + hidden_act="silu", # llama activation function + max_position_embeddings=8192, # llama rope max length + rms_norm_eps=1e-05, + rope_theta=500000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + aligner_ffn_mult=4, + aligner_enable_bias=True, + aligner_attention_probs_dropout_prob=0.1, + aligner_num_add_layers=8, + resampler_depth=6, + resampler_dim_head=64, + resampler_heads=8, + resampler_num_latents=64, + resampler_ff_mult=4, + initializer_range=0.02, + pad_token_id=None, + bos_token_id=128000, + eos_token_id=128009, + use_cache=False, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.aligner_ffn_mult = aligner_ffn_mult + self.aligner_enable_bias = aligner_enable_bias + self.aligner_attention_probs_dropout_prob = aligner_attention_probs_dropout_prob + self.aligner_num_add_layers = aligner_num_add_layers + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.resampler_depth = resampler_depth + self.resampler_dim_head = resampler_dim_head + self.resampler_heads = resampler_heads + self.resampler_num_latents = resampler_num_latents + self.resampler_ff_mult = resampler_ff_mult + + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # Subconfig + if protein_encoder_config is None: + protein_encoder_config = {} + logger.info("`protein_encoder_config` is `None`. Initializing the `SaProtConfig` with default values.") + self.protein_encoder_config = SaProtConfig(**protein_encoder_config) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["EvollaConfig"] diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py new file mode 100644 index 0000000000..f51f27d6d3 --- /dev/null +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -0,0 +1,1761 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/evolla/modular_evolla.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_evolla.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPast, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithPast, + ModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ( + ALL_ATTENTION_FUNCTIONS, + ModuleUtilsMixin, + PreTrainedModel, + find_pruneable_heads_and_indices, + get_parameter_dtype, + prune_linear_layer, +) +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_evolla import EvollaConfig, SaProtConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx + + +class EvollaSaProtEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + if config.emb_layer_norm_before: + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.layer_norm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + self.padding_idx = config.pad_token_id + if self.position_embedding_type == "absolute": + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + # remove the position_ids in EsmEmbeddings + self.position_ids = None + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support EVOLLA_SA_PROT-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: EVOLLA_SA_PROT has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all EVOLLA_SA_PROT model training runs + src_lengths = attention_mask.sum(-1) + mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths + embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to( + embeddings.dtype + ) + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. + # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +def rotate_half_esm(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_esm(x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half_esm(x) * sin) + + +class EvollaSaProtRotaryEmbedding(nn.Module): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + inv_freq = inv_freq + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=2): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :] + self._sin_cached = emb.sin()[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached), + ) + + +class EvollaSaProtSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.config = config + + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", version="4.54.0") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2) + attention_mask = encoder_attention_mask + else: + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # EVOLLA_SA_PROT code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in EvollaSaProtModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (None,) + return outputs + + +class EvollaSaProtSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EvollaSaProtFlashAttention2(EvollaSaProtSelfAttention): + """ + EVOLLA_SA_PROT flash attention module. This module inherits from `EvollaSaProtSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + self.dropout_prob = config.attention_probs_dropout_prob + + @deprecate_kwarg("past_key_value", version="4.54.0") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + # Flash attention doesn't support output_attentions or cross attention + if output_attentions or head_mask is not None or encoder_hidden_states is not None: + logger.warning_once( + "EvollaSaProtFlashAttention2 does not support output_attentions, head_mask, or cross_attention. " + "Falling back to the manual attention implementation. This warning can be removed using " + 'the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + + bsz, q_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + input_dtype = query_layer.dtype + device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_layer = query_layer.to(target_dtype) + key_layer = key_layer.to(target_dtype) + value_layer = value_layer.to(target_dtype) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # EVOLLA_SA_PROT code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings") + + # It would likely be faster to change self.transpose_for_scores to output the correct + # dimensions for flash_attention_2, but that would also mean changing the rotary embedding + # functions. Here we just permute the dimensions to match the expected input. + attn_output = _flash_attention_forward( + query_layer.permute(0, 2, 1, 3), + key_layer.permute(0, 2, 1, 3), + value_layer.permute(0, 2, 1, 3), + attention_mask, + query_length=q_len, + is_causal=self.is_decoder, + softmax_scale=1.0, + dropout=self.dropout_prob if self.training else 0.0, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + outputs = (attn_output, None) + if self.is_decoder: + outputs = outputs + (None,) + + return outputs + + +EVOLLA_SA_PROT_ATTENTION_CLASSES = { + "eager": EvollaSaProtSelfAttention, + "flash_attention_2": EvollaSaProtFlashAttention2, +} + + +class EvollaSaProtAttention(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.self = EVOLLA_SA_PROT_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.output = EvollaSaProtSelfOutput(config) + self.pruned_heads = set() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", version="4.54.0") + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + cache_position=None, + ): + hidden_states_ln = self.LayerNorm(hidden_states) + self_outputs = self.self( + hidden_states_ln, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +def gelu(x): + """ + This is the gelu implementation from the original EVOLLA_SA_PROT repo. Using F.gelu yields subtly wrong results. + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class EvollaSaProtIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = gelu(hidden_states) + return hidden_states + + +class EvollaSaProtOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EvollaSaProtLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = EvollaSaProtAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = EvollaSaProtAttention(config) + self.intermediate = EvollaSaProtIntermediate(config) + self.output = EvollaSaProtOutput(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + @deprecate_kwarg("past_key_value", version="4.54.0") + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + cache_position=None, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + layer_output = self.feed_forward_chunk(attention_output) + + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (None,) + return outputs + + def feed_forward_chunk(self, attention_output): + attention_output_ln = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(attention_output_ln) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class EvollaSaProtEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([EvollaSaProtLayer(config) for _ in range(config.num_hidden_layers)]) + self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + @deprecate_kwarg("past_key_value", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + cache_position=None, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class EvollaSaProtPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class EvollaSaProtPreTrainedModel(PreTrainedModel): + config: SaProtConfig + _no_split_modules = ["EvollaSaProtLayer"] + _supports_flash_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): + def __init__(self, config: SaProtConfig): + super().__init__(config) + self.embeddings = EvollaSaProtEmbeddings(config) + self.encoder = EvollaSaProtEncoder(config) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask) + sequence_output = encoder_outputs[0] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: tuple[int], device: torch.device = None, dtype: torch.float = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = get_parameter_dtype(self) + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + +class EvollaSequenceCompressorAttention(nn.Module): + def __init__(self, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents, mask): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D); n2: num of latent tokens + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk( + 2, dim=-1 + ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads + + q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3) + k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3) + v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3) + q = q * self.scale # batch_size, num_heads, num_latents, dim_head + + # attention + sim = torch.matmul(q, k.transpose(-1, -2)) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + bs, nh, skd, okd = sim.shape + ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd) + mask_exp = mask[:, None, None, :] + ones_exp = ones[None, :, :, None] + mask = mask_exp * ones_exp + + sim = sim.masked_fill((1 - mask).bool(), -1e4) + attn = sim.softmax(dim=-1) + out = torch.matmul(attn, v) + out = out.permute(0, 2, 1, 3) + + # [batch, seq, head, features] -> [batch, seq, head*features] + out = out.reshape(out.size(0), out.size(1), -1) + + return self.to_out(out) + + +class EvollaFeedForward(nn.Module): + def __init__(self, dim, mult=4): + super().__init__() + inner_dim = int(dim * mult) + + self.norm = nn.LayerNorm(dim) + self.fc1 = nn.Linear(dim, inner_dim, bias=False) + self.activation = nn.GELU() + self.fc2 = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + return self.fc2(self.activation(self.fc1(self.norm(x)))) + + +class EvollaSequenceCompressorResampler(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + protein_repr_dim = config.protein_encoder_config.hidden_size + self.num_latents = config.resampler_num_latents + self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True) + self.layers = nn.ModuleList([]) + for _ in range(config.resampler_depth): + self.layers.append( + nn.ModuleList( + [ + EvollaSequenceCompressorAttention( + dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads + ), + EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(config.hidden_size) + self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size) + + def forward(self, embeds, mask): + b = embeds.shape[0] + + bs, _ = mask.shape # bs, max_protein_length + latent_mask = torch.ones(bs, self.num_latents).to(mask.device) + mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents + + # blocks + ones = torch.ones(b).to(self.latents.device) + latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d] + latents = latents.to(embeds.dtype) + for attn, ff in self.layers: + latents = attn(embeds, latents, mask) + latents + latents = ff(latents) + latents + + transformed_feature = self.protein_projector(latents) + + return self.norm(transformed_feature) + + +@dataclass +@auto_docstring +class EvollaProteinEncoderModelOutput(ModelOutput): + sequence_compressor_output: torch.FloatTensor = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EvollaProteinEncoder(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config) + self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config) + + @can_return_tuple + def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs): + protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + protein_embeds = protein_output.last_hidden_state + sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask) + + return EvollaProteinEncoderModelOutput( + sequence_compressor_output=sequence_repr, + last_hidden_state=protein_output.last_hidden_state, + ) + + +class EvollaSequenceAlignerCrossAttention(nn.Module): + def __init__( + self, + config, + protein_encoder_dim: Optional[int] = None, + structure_encoder_dim: Optional[int] = None, + msa_encoder_dim: Optional[int] = None, + ): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.scale = self.num_attention_heads**-0.5 + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob + enable_bias = config.aligner_enable_bias + ffn_mult = config.aligner_ffn_mult + + self.query = nn.Linear(self.hidden_size, self.all_head_size) + if protein_encoder_dim is not None: + self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + else: + self.key_protein = None + self.value_protein = None + + if structure_encoder_dim is not None: + self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + else: + self.key_structure = None + self.value_structure = None + + if msa_encoder_dim is not None: + self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + else: + self.key_msa = None + self.value_msa = None + + self.attention_norm = EvollaRMSNorm(self.hidden_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias) + + self.ff = EvollaFeedForward(self.hidden_size, ffn_mult) + self.gate_attention = nn.Parameter(torch.tensor([0.0])) + self.gate_ffw = nn.Parameter(torch.tensor([0.0])) + + def cross_attention( + self, + query_states, + protein_key_value_states, + structure_key_value_states, + msa_key_value_states, + query_attn_mask, + protein_kv_attn_mask, + structure_kv_attn_mask, + msa_kv_attn_mask, + ): + """ + query_states: text + key_value_states: protein + query_states: [bs, query_seq_len, dim] + key_value_states: [bs, kv_seq_len, dim] + query_attn_mask: [bs, query_seq_len] + kv_attn_mask: [bs, kv_seq_len] + """ + + # Concatenate protein and structure + kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask] + kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None] + if not kv_attn_mask: + raise ValueError("At least one modality should be provided for cross attention.") + kv_attn_mask = torch.cat(kv_attn_mask, dim=1) + + query_layer = self.attention_norm(query_states) + + # Warning: This place might cause issues, refers to + # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13 + # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable + # Apply linear transformation to input_query, input_key, and input_value + query_layer = self.query(query_layer) # [bs, querylength, dim] + + if self.key_protein is not None and self.value_protein is not None: + protein_key_value_states = protein_key_value_states.to(query_states) + key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim] + value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim] + else: + key_layer_protein = None + value_layer_protein = None + + if self.key_structure is not None and self.value_structure is not None: + structure_key_value_states = structure_key_value_states.to(query_states) + key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim] + value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim] + else: + key_layer_structure = None + value_layer_structure = None + + if self.key_msa is not None and self.value_msa is not None: + msa_key_value_states = msa_key_value_states.to(query_states) + key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim] + value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim] + else: + key_layer_msa = None + value_layer_msa = None + + key_layer = [key_layer_protein, key_layer_structure, key_layer_msa] + key_layer = [_ for _ in key_layer if _ is not None] + key_layer = torch.cat(key_layer, dim=1) + + value_layer = [value_layer_protein, value_layer_structure, value_layer_msa] + value_layer = [_ for _ in value_layer if _ is not None] + value_layer = torch.cat(value_layer, dim=1) + + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3) + + new_key_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3) + + new_value_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3) + + query_layer = query_layer * self.scale + + # attention_mask: [bs, 1, querylength, keylength] + if query_attn_mask is None: + query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device) + attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :] + # Compute the scaled dot-product attention scores + attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength] + attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stablize score + attention_scores = attn_weights.masked_fill( + (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min + ) # [bs, numheads, querylength, keylength] + + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads] + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + return context_layer + + def forward( + self, + query_states, + protein_kv_states, + structure_kv_states, + msa_kv_states, + query_attn_mask, + protein_kv_attn_mask=None, + structure_kv_attn_mask=None, + msa_kv_attn_mask=None, + protein_batch_mask=None, + structure_batch_mask=None, + msa_batch_mask=None, + past_key_value=None, + ): + if protein_kv_states is not None: + bs, protein_kv_seq_len, dim = protein_kv_states.shape + if protein_kv_attn_mask is None: + protein_kv_attn_mask = ( + torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device) + * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T + ).to(protein_kv_states.device) + else: + protein_kv_attn_mask = None + + if structure_kv_states is not None: + bs, structure_kv_seq_len, dim = structure_kv_states.shape + if structure_kv_attn_mask is None: + structure_kv_attn_mask = ( + torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device) + * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T + ).to(structure_kv_states.device) + else: + structure_kv_attn_mask = None + + if msa_kv_states is not None: + bs, msa_kv_seq_len, dim = msa_kv_states.shape + if msa_kv_attn_mask is None: + msa_kv_attn_mask = ( + torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device) + * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T + ).to(msa_kv_states.device) + else: + msa_kv_attn_mask = None + hidden_states = query_states + # only when there's at least one valid modality, crossattention will be performed + if ( + (protein_kv_states is not None and protein_kv_attn_mask.any()) + or (structure_kv_states is not None and structure_kv_attn_mask.any()) + or (msa_kv_states is not None and msa_kv_attn_mask.any()) + ): + residual = hidden_states + hidden_states = self.cross_attention( + query_states=hidden_states, + protein_key_value_states=protein_kv_states, + structure_key_value_states=structure_kv_states, + msa_key_value_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_kv_attn_mask=protein_kv_attn_mask, + structure_kv_attn_mask=structure_kv_attn_mask, + msa_kv_attn_mask=msa_kv_attn_mask, + ) # [bs, query_seq_len, dim] + # tanh gate + hidden_states = torch.tanh(self.gate_attention) * hidden_states + + hidden_states = residual + hidden_states # input_query + + residual = hidden_states + hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw) + hidden_states = residual + hidden_states + + return hidden_states + + +@use_kernel_forward_from_hub("RMSNorm") +class EvollaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + EvollaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class EvollaRotaryEmbedding(nn.Module): + def __init__(self, config: EvollaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class EvollaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EvollaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: EvollaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EvollaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: EvollaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = EvollaAttention(config=config, layer_idx=layer_idx) + + self.mlp = EvollaMLP(config) + self.input_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0: + self.adapter = EvollaSequenceAlignerCrossAttention( + config, + protein_encoder_dim=config.hidden_size, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + protein_kv_states: Optional[torch.Tensor] = None, + structure_kv_states: Optional[torch.Tensor] = None, + msa_kv_states: Optional[torch.Tensor] = None, + protein_batch_mask: Optional[torch.Tensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + query_attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if hasattr(self, "adapter"): + hidden_states = self.adapter( + query_states=hidden_states, + protein_kv_states=protein_kv_states, + structure_kv_states=structure_kv_states, + msa_kv_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + ) + + return hidden_states + + +@auto_docstring +class EvollaPreTrainedModel(PreTrainedModel): + config: EvollaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["EvollaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _supports_static_cache = True + _supports_attention_backend = False + _can_record_outputs = { + "hidden_states": EvollaDecoderLayer, + "attentions": EvollaAttention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, EvollaRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, EvollaSequenceAlignerCrossAttention): + module.gate_attention.zero_() + module.gate_ffw.zero_() + module.attention_norm.weight.data.fill_(1.0) + elif isinstance(module, EvollaSequenceCompressorResampler): + module.latents.data.normal_(mean=0.0, std=std) + + +class EvollaModel(EvollaPreTrainedModel): + def __init__(self, config: EvollaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + self.protein_encoder = EvollaProteinEncoder(config=config) + self.layers = nn.ModuleList( + [ + EvollaDecoderLayer( + config=config, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = EvollaRotaryEmbedding(config=config) + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + @check_model_inputs + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + protein_input_ids: Optional[torch.LongTensor] = None, + protein_attention_mask: Optional[torch.Tensor] = None, + structure_feats: Optional[torch.FloatTensor] = None, + msa_feats: Optional[torch.FloatTensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. + structure_feats (torch.FloatTensor): + The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + msa_feats (torch.FloatTensor): + The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + structure_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now. + msa_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now. + """ + # If not provided `protein_feats`, use the `protein_encoder` to get the protein features + if protein_input_ids is not None and protein_attention_mask is not None: + protein_outputs = self.protein_encoder( + input_ids=protein_input_ids, + attention_mask=protein_attention_mask, + ) + protein_feats = protein_outputs.sequence_compressor_output + protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + protein_kv_states=protein_feats, + structure_kv_states=structure_feats, + msa_kv_states=msa_feats, + protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + query_attn_mask=attention_mask, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + return output + + +class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.model = EvollaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + return self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, # text input ids + attention_mask: Optional[torch.Tensor] = None, # text attention mask + inputs_embeds: Optional[torch.FloatTensor] = None, # text input embeddings + labels: Optional[torch.LongTensor] = None, + protein_input_ids: torch.LongTensor = None, + protein_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs, + ): + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. + + Example: + + ```python + >>> from transformers import EvollaProcessor, EvollaForProteinText2Text + >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf") + >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf") + + >>> protein_information = { + "aa_seq": "your amino acid sequence", + "foldseek": "your foldseek sequence", + } + >>> question = "What is the function of this protein?" + >>> message = [ + {"role": "system", "content": "You are an AI expert that can answer any questions about protein."}, + {"role": "user", "content": question}, + ] + + >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest") + >>> outputs = model.generate(**inputs) + + >>> print(processor.batch_decode(outputs, skip_special_tokens=True)) + ```""" + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + protein_input_ids=protein_input_ids, + protein_attention_mask=protein_attention_mask, + use_cache=use_cache, + **kwargs, + ) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + lm_outputs = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + return lm_outputs + + +__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"] diff --git a/src/transformers/models/evolla/modular_evolla.py b/src/transformers/models/evolla/modular_evolla.py new file mode 100644 index 0000000000..30cf93b5c9 --- /dev/null +++ b/src/transformers/models/evolla/modular_evolla.py @@ -0,0 +1,1008 @@ +# coding=utf-8 +# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithPast, + ModelOutput, +) +from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype +from ...utils import ( + auto_docstring, + can_return_tuple, + logging, +) +from ...utils.generic import check_model_inputs +from ..esm.modeling_esm import ( + EsmAttention, + EsmEmbeddings, + EsmEncoder, + EsmIntermediate, + EsmLayer, + EsmOutput, + EsmPooler, + EsmSelfAttention, + EsmSelfOutput, +) +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from .configuration_evolla import EvollaConfig, SaProtConfig + + +logger = logging.get_logger(__name__) + + +class EvollaSaProtEmbeddings(EsmEmbeddings): + def __init__(self, config): + super().__init__() + # remove the position_ids in EsmEmbeddings + self.position_ids = None + + +def rotate_half_esm(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_esm(x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half_esm(x) * sin) + + +class EvollaSaProtRotaryEmbedding(nn.Module): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + inv_freq = inv_freq + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=2): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :] + self._sin_cached = emb.sin()[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached), + ) + + +class EvollaSaProtSelfAttention(EsmSelfAttention, nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + nn.Module.__init__() + self.config = config + + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + +class EvollaSaProtSelfOutput(EsmSelfOutput): + pass + + +class EvollaSaProtAttention(EsmAttention): + pass + + +class EvollaSaProtIntermediate(EsmIntermediate): + pass + + +class EvollaSaProtOutput(EsmOutput): + pass + + +class EvollaSaProtLayer(EsmLayer): + pass + + +class EvollaSaProtEncoder(EsmEncoder): + pass + + +class EvollaSaProtPooler(EsmPooler): + pass + + +@auto_docstring +class EvollaSaProtPreTrainedModel(PreTrainedModel): + config: SaProtConfig + _no_split_modules = ["EvollaSaProtLayer"] + _supports_flash_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): + def __init__(self, config: SaProtConfig): + super().__init__(config) + self.embeddings = EvollaSaProtEmbeddings(config) + self.encoder = EvollaSaProtEncoder(config) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask) + sequence_output = encoder_outputs[0] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: tuple[int], device: torch.device = None, dtype: torch.float = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = get_parameter_dtype(self) + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + +class EvollaSequenceCompressorAttention(nn.Module): + def __init__(self, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents, mask): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D); n2: num of latent tokens + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk( + 2, dim=-1 + ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads + + q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3) + k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3) + v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3) + q = q * self.scale # batch_size, num_heads, num_latents, dim_head + + # attention + sim = torch.matmul(q, k.transpose(-1, -2)) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + bs, nh, skd, okd = sim.shape + ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd) + mask_exp = mask[:, None, None, :] + ones_exp = ones[None, :, :, None] + mask = mask_exp * ones_exp + + sim = sim.masked_fill((1 - mask).bool(), -1e4) + attn = sim.softmax(dim=-1) + out = torch.matmul(attn, v) + out = out.permute(0, 2, 1, 3) + + # [batch, seq, head, features] -> [batch, seq, head*features] + out = out.reshape(out.size(0), out.size(1), -1) + + return self.to_out(out) + + +class EvollaFeedForward(nn.Module): + def __init__(self, dim, mult=4): + super().__init__() + inner_dim = int(dim * mult) + + self.norm = nn.LayerNorm(dim) + self.fc1 = nn.Linear(dim, inner_dim, bias=False) + self.activation = nn.GELU() + self.fc2 = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + return self.fc2(self.activation(self.fc1(self.norm(x)))) + + +class EvollaSequenceCompressorResampler(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + protein_repr_dim = config.protein_encoder_config.hidden_size + self.num_latents = config.resampler_num_latents + self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True) + self.layers = nn.ModuleList([]) + for _ in range(config.resampler_depth): + self.layers.append( + nn.ModuleList( + [ + EvollaSequenceCompressorAttention( + dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads + ), + EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(config.hidden_size) + self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size) + + def forward(self, embeds, mask): + b = embeds.shape[0] + + bs, _ = mask.shape # bs, max_protein_length + latent_mask = torch.ones(bs, self.num_latents).to(mask.device) + mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents + + # blocks + ones = torch.ones(b).to(self.latents.device) + latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d] + latents = latents.to(embeds.dtype) + for attn, ff in self.layers: + latents = attn(embeds, latents, mask) + latents + latents = ff(latents) + latents + + transformed_feature = self.protein_projector(latents) + + return self.norm(transformed_feature) + + +@dataclass +@auto_docstring +class EvollaProteinEncoderModelOutput(ModelOutput): + sequence_compressor_output: torch.FloatTensor = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EvollaProteinEncoder(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config) + self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config) + + @can_return_tuple + def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs): + protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + protein_embeds = protein_output.last_hidden_state + sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask) + + return EvollaProteinEncoderModelOutput( + sequence_compressor_output=sequence_repr, + last_hidden_state=protein_output.last_hidden_state, + ) + + +class EvollaSequenceAlignerCrossAttention(nn.Module): + def __init__( + self, + config, + protein_encoder_dim: Optional[int] = None, + structure_encoder_dim: Optional[int] = None, + msa_encoder_dim: Optional[int] = None, + ): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.scale = self.num_attention_heads**-0.5 + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob + enable_bias = config.aligner_enable_bias + ffn_mult = config.aligner_ffn_mult + + self.query = nn.Linear(self.hidden_size, self.all_head_size) + if protein_encoder_dim is not None: + self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + else: + self.key_protein = None + self.value_protein = None + + if structure_encoder_dim is not None: + self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + else: + self.key_structure = None + self.value_structure = None + + if msa_encoder_dim is not None: + self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + else: + self.key_msa = None + self.value_msa = None + + self.attention_norm = EvollaRMSNorm(self.hidden_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias) + + self.ff = EvollaFeedForward(self.hidden_size, ffn_mult) + self.gate_attention = nn.Parameter(torch.tensor([0.0])) + self.gate_ffw = nn.Parameter(torch.tensor([0.0])) + + def cross_attention( + self, + query_states, + protein_key_value_states, + structure_key_value_states, + msa_key_value_states, + query_attn_mask, + protein_kv_attn_mask, + structure_kv_attn_mask, + msa_kv_attn_mask, + ): + """ + query_states: text + key_value_states: protein + query_states: [bs, query_seq_len, dim] + key_value_states: [bs, kv_seq_len, dim] + query_attn_mask: [bs, query_seq_len] + kv_attn_mask: [bs, kv_seq_len] + """ + + # Concatenate protein and structure + kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask] + kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None] + if not kv_attn_mask: + raise ValueError("At least one modality should be provided for cross attention.") + kv_attn_mask = torch.cat(kv_attn_mask, dim=1) + + query_layer = self.attention_norm(query_states) + + # Warning: This place might cause issues, refers to + # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13 + # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable + # Apply linear transformation to input_query, input_key, and input_value + query_layer = self.query(query_layer) # [bs, querylength, dim] + + if self.key_protein is not None and self.value_protein is not None: + protein_key_value_states = protein_key_value_states.to(query_states) + key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim] + value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim] + else: + key_layer_protein = None + value_layer_protein = None + + if self.key_structure is not None and self.value_structure is not None: + structure_key_value_states = structure_key_value_states.to(query_states) + key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim] + value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim] + else: + key_layer_structure = None + value_layer_structure = None + + if self.key_msa is not None and self.value_msa is not None: + msa_key_value_states = msa_key_value_states.to(query_states) + key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim] + value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim] + else: + key_layer_msa = None + value_layer_msa = None + + key_layer = [key_layer_protein, key_layer_structure, key_layer_msa] + key_layer = [_ for _ in key_layer if _ is not None] + key_layer = torch.cat(key_layer, dim=1) + + value_layer = [value_layer_protein, value_layer_structure, value_layer_msa] + value_layer = [_ for _ in value_layer if _ is not None] + value_layer = torch.cat(value_layer, dim=1) + + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3) + + new_key_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3) + + new_value_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3) + + query_layer = query_layer * self.scale + + # attention_mask: [bs, 1, querylength, keylength] + if query_attn_mask is None: + query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device) + attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :] + # Compute the scaled dot-product attention scores + attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength] + attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stablize score + attention_scores = attn_weights.masked_fill( + (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min + ) # [bs, numheads, querylength, keylength] + + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads] + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + return context_layer + + def forward( + self, + query_states, + protein_kv_states, + structure_kv_states, + msa_kv_states, + query_attn_mask, + protein_kv_attn_mask=None, + structure_kv_attn_mask=None, + msa_kv_attn_mask=None, + protein_batch_mask=None, + structure_batch_mask=None, + msa_batch_mask=None, + past_key_value=None, + ): + if protein_kv_states is not None: + bs, protein_kv_seq_len, dim = protein_kv_states.shape + if protein_kv_attn_mask is None: + protein_kv_attn_mask = ( + torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device) + * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T + ).to(protein_kv_states.device) + else: + protein_kv_attn_mask = None + + if structure_kv_states is not None: + bs, structure_kv_seq_len, dim = structure_kv_states.shape + if structure_kv_attn_mask is None: + structure_kv_attn_mask = ( + torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device) + * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T + ).to(structure_kv_states.device) + else: + structure_kv_attn_mask = None + + if msa_kv_states is not None: + bs, msa_kv_seq_len, dim = msa_kv_states.shape + if msa_kv_attn_mask is None: + msa_kv_attn_mask = ( + torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device) + * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T + ).to(msa_kv_states.device) + else: + msa_kv_attn_mask = None + hidden_states = query_states + # only when there's at least one valid modality, crossattention will be performed + if ( + (protein_kv_states is not None and protein_kv_attn_mask.any()) + or (structure_kv_states is not None and structure_kv_attn_mask.any()) + or (msa_kv_states is not None and msa_kv_attn_mask.any()) + ): + residual = hidden_states + hidden_states = self.cross_attention( + query_states=hidden_states, + protein_key_value_states=protein_kv_states, + structure_key_value_states=structure_kv_states, + msa_key_value_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_kv_attn_mask=protein_kv_attn_mask, + structure_kv_attn_mask=structure_kv_attn_mask, + msa_kv_attn_mask=msa_kv_attn_mask, + ) # [bs, query_seq_len, dim] + # tanh gate + hidden_states = torch.tanh(self.gate_attention) * hidden_states + + hidden_states = residual + hidden_states # input_query + + residual = hidden_states + hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw) + hidden_states = residual + hidden_states + + return hidden_states + + +class EvollaRMSNorm(LlamaRMSNorm): + pass + + +class EvollaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class EvollaMLP(LlamaMLP): + pass + + +class EvollaAttention(LlamaAttention): + pass + + +class EvollaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: EvollaConfig, layer_idx: int): + super().__init__(config, layer_idx) + if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0: + self.adapter = EvollaSequenceAlignerCrossAttention( + config, + protein_encoder_dim=config.hidden_size, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + protein_kv_states: Optional[torch.Tensor] = None, + structure_kv_states: Optional[torch.Tensor] = None, + msa_kv_states: Optional[torch.Tensor] = None, + protein_batch_mask: Optional[torch.Tensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + query_attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if hasattr(self, "adapter"): + hidden_states = self.adapter( + query_states=hidden_states, + protein_kv_states=protein_kv_states, + structure_kv_states=structure_kv_states, + msa_kv_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + ) + + return hidden_states + + +class EvollaPreTrainedModel(LlamaPreTrainedModel): + _supports_attention_backend = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, EvollaRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, EvollaSequenceAlignerCrossAttention): + module.gate_attention.zero_() + module.gate_ffw.zero_() + module.attention_norm.weight.data.fill_(1.0) + elif isinstance(module, EvollaSequenceCompressorResampler): + module.latents.data.normal_(mean=0.0, std=std) + + +class EvollaModel(EvollaPreTrainedModel): + def __init__(self, config: EvollaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + self.protein_encoder = EvollaProteinEncoder(config=config) + self.layers = nn.ModuleList( + [ + EvollaDecoderLayer( + config=config, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = EvollaRotaryEmbedding(config=config) + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + @check_model_inputs + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + protein_input_ids: Optional[torch.LongTensor] = None, + protein_attention_mask: Optional[torch.Tensor] = None, + structure_feats: Optional[torch.FloatTensor] = None, + msa_feats: Optional[torch.FloatTensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. + structure_feats (torch.FloatTensor): + The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + msa_feats (torch.FloatTensor): + The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + structure_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now. + msa_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now. + """ + # If not provided `protein_feats`, use the `protein_encoder` to get the protein features + if protein_input_ids is not None and protein_attention_mask is not None: + protein_outputs = self.protein_encoder( + input_ids=protein_input_ids, + attention_mask=protein_attention_mask, + ) + protein_feats = protein_outputs.sequence_compressor_output + protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + protein_kv_states=protein_feats, + structure_kv_states=structure_feats, + msa_kv_states=msa_feats, + protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + query_attn_mask=attention_mask, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + return output + + +class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.model = EvollaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + return self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, # text input ids + attention_mask: Optional[torch.Tensor] = None, # text attention mask + inputs_embeds: Optional[torch.FloatTensor] = None, # text input embeddings + labels: Optional[torch.LongTensor] = None, + protein_input_ids: torch.LongTensor = None, + protein_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs, + ): + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. + + Example: + + ```python + >>> from transformers import EvollaProcessor, EvollaForProteinText2Text + >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf") + >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf") + + >>> protein_information = { + "aa_seq": "your amino acid sequence", + "foldseek": "your foldseek sequence", + } + >>> question = "What is the function of this protein?" + >>> message = [ + {"role": "system", "content": "You are an AI expert that can answer any questions about protein."}, + {"role": "user", "content": question}, + ] + + >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest") + >>> outputs = model.generate(**inputs) + + >>> print(processor.batch_decode(outputs, skip_special_tokens=True)) + ```""" + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + protein_input_ids=protein_input_ids, + protein_attention_mask=protein_attention_mask, + use_cache=use_cache, + **kwargs, + ) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + lm_outputs = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + return lm_outputs + + +__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"] diff --git a/src/transformers/models/evolla/processing_evolla.py b/src/transformers/models/evolla/processing_evolla.py new file mode 100644 index 0000000000..d44981bff5 --- /dev/null +++ b/src/transformers/models/evolla/processing_evolla.py @@ -0,0 +1,247 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for EVOLLA. +""" + +import os +from typing import Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ( + ProcessorMixin, +) +from ..auto import AutoTokenizer + + +PROTEIN_VALID_KEYS = ["aa_seq", "foldseek", "msa"] + + +class EvollaProcessor(ProcessorMixin): + r""" + Constructs a EVOLLA processor which wraps a LLama tokenizer and SaProt tokenizer (EsmTokenizer) into a single processor. + + [`EvollaProcessor`] offers all the functionalities of [`EsmTokenizer`] and [`LlamaTokenizerFast`]. See the + docstring of [`~EvollaProcessor.__call__`] and [`~EvollaProcessor.decode`] for more information. + + Args: + protein_tokenizer (`EsmTokenizer`): + An instance of [`EsmTokenizer`]. The protein tokenizer is a required input. + tokenizer (`LlamaTokenizerFast`, *optional*): + An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input. + protein_max_length (`int`, *optional*, defaults to 1024): + The maximum length of the sequence to be generated. + text_max_length (`int`, *optional*, defaults to 512): + The maximum length of the text to be generated. + """ + + attributes = ["protein_tokenizer", "tokenizer"] + valid_kwargs = ["sequence_max_length"] + # protein_tokenizer_class = "EsmTokenizer" + # tokenizer_class = "LlamaTokenizerFast" + protein_tokenizer_class = "AutoTokenizer" + tokenizer_class = "AutoTokenizer" + protein_tokenizer_dir_name = "protein_tokenizer" + # tokenizer_dir_name = "text_tokenizer" + + def __init__(self, protein_tokenizer, tokenizer=None, protein_max_length=1024, text_max_length=512, **kwargs): + if protein_tokenizer is None: + raise ValueError("You need to specify an `protein_tokenizer`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(protein_tokenizer, tokenizer) + + self.tokenizer.pad_token = "<|reserved_special_token_0|>" + self.protein_max_length = protein_max_length + self.text_max_length = text_max_length + + def process_proteins(self, proteins, protein_max_length=1024): + sa_sequences = [] + for protein in proteins: + aa_seq = protein.get("aa_seq") + foldseek = protein.get("foldseek") + sa_sequence = "".join([s.upper() + f.lower() for s, f in zip(aa_seq, foldseek)]) + sa_sequences.append(sa_sequence) + + sa_tokens = self.protein_tokenizer.batch_encode_plus( + sa_sequences, return_tensors="pt", truncation=True, max_length=protein_max_length, padding=True + ) + return sa_tokens + + def process_text( + self, + texts, + text_max_length: int = 512, + ): + prompts = [] + for messages in texts: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + prompts.append(prompt) + + prompt_inputs = self.tokenizer( + prompts, + add_special_tokens=False, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=text_max_length, + ) + return prompt_inputs + + def __call__( + self, + proteins: Optional[Union[list[dict], dict]] = None, + messages_list: Optional[Union[list[list[dict]], list[dict]]] = None, + protein_max_length: Optional[int] = None, + text_max_length: Optional[int] = None, + **kwargs, + ): + r"""This method takes batched or non-batched proteins and messages_list and converts them into format that can be used by + the model. + + Args: + proteins (`Union[List[dict], dict]`): + A list of dictionaries or a single dictionary containing the following keys: + - `"aa_seq"` (`str`) -- The amino acid sequence of the protein. + - `"foldseek"` (`str`) -- The foldseek string of the protein. + messages_list (`Union[List[List[dict]], List[dict]]`): + A list of lists of dictionaries or a list of dictionaries containing the following keys: + - `"role"` (`str`) -- The role of the message. + - `"content"` (`str`) -- The content of the message. + protein_max_length (`int`, *optional*, defaults to 1024): + The maximum length of the sequence to be generated. + text_max_length (`int`, *optional*, defaults to 512): + The maximum length of the text. + + Return: + a dict with following keys: + - `protein_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the protein sequence. + - `protein_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the protein sequence. + - `text_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the text sequence. + - `text_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the text sequence. + """ + # proteins and messages_list should be provided + if proteins is None or messages_list is None: + raise ValueError("You need to specify `messages_list` and `proteins`.") + + protein_max_length = protein_max_length if protein_max_length is not None else self.protein_max_length + text_max_length = text_max_length if text_max_length is not None else self.text_max_length + + # proteins should be List[dict] + if isinstance(proteins, dict): + proteins = [proteins] + # messages_list should be List[List[dict]] + if isinstance(messages_list, (list, tuple)) and not isinstance(messages_list[0], (list, tuple)): + messages_list = [messages_list] + # Check if batched proteins are in the correct format + if isinstance(proteins, (list, tuple)) and not all(isinstance(p, dict) for p in proteins): + raise ValueError("The proteins should be a list of dictionaries, but not all elements are dictionaries.") + if isinstance(proteins, (list, tuple)) and not all( + all(k in PROTEIN_VALID_KEYS for k in p.keys()) for p in proteins + ): + raise ValueError( + "There should be a list of dictionaries with keys: " + f"{', '.join(PROTEIN_VALID_KEYS)} for each protein." + f"But got: {proteins}" + ) + # Check if batched messages_list is in the correct format + if isinstance(messages_list, (list, tuple)): + for messages in messages_list: + if not isinstance(messages, (list, tuple)): + raise ValueError(f"Each messages in messages_list should be a list instead of {type(messages)}.") + if not all(isinstance(m, dict) for m in messages): + raise ValueError( + "Each message in messages_list should be a list of dictionaries, but not all elements are dictionaries." + ) + if any(len(m.keys()) != 2 for m in messages) or any( + set(m.keys()) != {"role", "content"} for m in messages + ): + raise ValueError( + "Each message in messages_list should be a list of dictionaries with two keys: 'role' and 'content'." + f"But got: {messages}" + ) + else: + raise ValueError( + f"The messages_list should be a list of lists of dictionaries, but it's {type(messages_list)}." + ) + sa_tokens = self.process_proteins(proteins, protein_max_length) + + text_tokens = self.process_text(messages_list, text_max_length) + + return BatchFeature( + data={ + "protein_input_ids": sa_tokens["input_ids"], + "protein_attention_mask": sa_tokens["attention_mask"], + "input_ids": text_tokens["input_ids"], + "attention_mask": text_tokens["attention_mask"], + } + ) + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + def protein_batch_decode(self, *args, **kwargs): + return self.protein_tokenizer.batch_decode(*args, **kwargs) + + def protein_decode(self, *args, **kwargs): + return self.protein_tokenizer.decode(*args, **kwargs) + + # overwrite to save the protein tokenizer in a separate folder + # Adapted from instructblip.processing_instructblip.py (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/models/instructblip/processing_instructblip.py#L191-L221) + def save_pretrained(self, save_directory, **kwargs): + # only save the protein tokenizer in sub_dir + self.protein_tokenizer.save_pretrained(os.path.join(save_directory, self.protein_tokenizer_dir_name)) + + # we modify the attributes so that only the text tokenizer are saved in the main folder + protein_tokenizer_present = "protein_tokenizer" in self.attributes + # find the correct position of it in the attributes list + protein_tokenizer_index = self.attributes.index("protein_tokenizer") if protein_tokenizer_present else None + if protein_tokenizer_present and protein_tokenizer_index is not None: + self.attributes.remove("protein_tokenizer") + + outputs = super().save_pretrained(save_directory, **kwargs) + + if protein_tokenizer_present and protein_tokenizer_index is not None: + self.attributes.insert(protein_tokenizer_index, "protein_tokenizer") + + return outputs + + # overwirte to load the protein tokenizer from a separate folder + # Adapted from instructblip.processing_instructblip.py (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/models/instructblip/processing_instructblip.py#L191-L221) + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + # if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs' + if isinstance(processor, tuple): + processor = processor[0] + protein_tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, subfolder=cls.protein_tokenizer_dir_name + ) + + processor.protein_tokenizer = protein_tokenizer + + return processor + + +__all__ = ["EvollaProcessor"] diff --git a/tests/models/evolla/__init__.py b/tests/models/evolla/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/evolla/test_modeling_evolla.py b/tests/models/evolla/test_modeling_evolla.py new file mode 100644 index 0000000000..2864361077 --- /dev/null +++ b/tests/models/evolla/test_modeling_evolla.py @@ -0,0 +1,397 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Evolla model.""" + +import unittest + +from parameterized import parameterized + +from transformers import BitsAndBytesConfig, EvollaConfig, is_torch_available +from transformers.testing_utils import ( + TestCasePlus, + require_bitsandbytes, + require_torch, + require_torch_sdpa, + slow, + torch_device, +) +from transformers.utils import ( + cached_property, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _config_zero_init, + ids_tensor, + random_attention_mask, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import EvollaForProteinText2Text, EvollaModel, EvollaProcessor + + +class EvollaModelTester: + def __init__( + self, + parent, + batch_size=1, + is_training=False, + text_seq_length=20, + text_vocab_size=100, + protein_seq_length=10, + protein_vocab_size=20, + hidden_size=4, # llama hidden size + intermediate_size=7, # llama intermediate size + num_hidden_layers=1, # llama hidden layers + num_attention_heads=2, # llama attention heads + num_key_value_heads=2, # llama key value heads + protein_hidden_size=8, # protein encoder hidden size + protein_num_hidden_layers=1, # protein encoder hidden layers + protein_num_attention_heads=4, # protein encoder attention heads + protein_intermediate_size=11, # protein encoder intermediate size + resampler_num_latents=7, # sequence compressor num latents + resampler_ff_mult=1, # sequence compressor ff mult + resampler_depth=2, # sequence compressor depth + resampler_dim_head=4, # sequence compressor dim head + resampler_heads=2, # sequence compressor heads + aligner_num_add_layers=1, # sequence aligner num add layers + aligner_ffn_mult=1, # sequence aligner ffn mult + use_input_mask=True, + ): + self.parent = parent + self.batch_size = batch_size + self.protein_seq_length = protein_seq_length + self.protein_vocab_size = protein_vocab_size + self.text_seq_length = text_seq_length + self.text_vocab_size = text_vocab_size + self.seq_length = text_seq_length + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.protein_hidden_size = protein_hidden_size + self.protein_num_hidden_layers = protein_num_hidden_layers + self.protein_num_attention_heads = protein_num_attention_heads + self.protein_intermediate_size = protein_intermediate_size + + self.resampler_num_latents = resampler_num_latents + self.resampler_ff_mult = resampler_ff_mult + self.resampler_depth = resampler_depth + self.resampler_dim_head = resampler_dim_head + self.resampler_heads = resampler_heads + + self.aligner_num_add_layers = aligner_num_add_layers + self.aligner_ffn_mult = aligner_ffn_mult + + self.use_input_mask = use_input_mask + self.is_training = is_training + + @property + def is_encoder_decoder(self): + return False + + def prepare_config_and_inputs(self, num_proteins=None): + batch_size = num_proteins if num_proteins is not None else self.batch_size + text_input_ids = ids_tensor([batch_size, self.text_seq_length], self.text_vocab_size) + + protein_input_ids = ids_tensor([batch_size, self.protein_seq_length], self.protein_vocab_size) + + if self.use_input_mask: + text_input_mask = random_attention_mask([batch_size, self.text_seq_length]) + protein_input_mask = random_attention_mask([batch_size, self.protein_seq_length]) + + config = self.get_config() + return (config, text_input_ids, text_input_mask, protein_input_ids, protein_input_mask) + + def get_config(self): + return EvollaConfig( + protein_encoder_config={ + "vocab_size": self.protein_vocab_size, + "hidden_size": self.protein_hidden_size, + "num_hidden_layers": self.protein_num_hidden_layers, + "num_attention_heads": self.protein_num_attention_heads, + "intermediate_size": self.protein_intermediate_size, + }, + vocab_size=self.text_vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + aligner_ffn_mult=self.aligner_ffn_mult, + aligner_num_add_layers=self.aligner_num_add_layers, + resampler_depth=self.resampler_depth, + resampler_dim_head=self.resampler_dim_head, + resampler_heads=self.resampler_heads, + resampler_num_latents=self.resampler_num_latents, + resampler_ff_mult=self.resampler_ff_mult, + ) + + def create_and_check_model( + self, + config, + input_ids, + input_mask, + protein_input_ids, + protein_input_mask, + batch_size=None, + ): + batch_size = batch_size if batch_size is not None else self.batch_size + model = EvollaModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + protein_input_ids=protein_input_ids, + protein_attention_mask=protein_input_mask, + ) + self.parent.assertEqual(result.last_hidden_state.shape, (batch_size, input_ids.shape[1], self.hidden_size)) + + def create_and_check_model_gen( + self, + config, + input_ids, + input_mask, + protein_input_ids, + protein_input_mask, + ): + model = EvollaForProteinText2Text(config) + model.to(torch_device) + model.eval() + model.generate( + input_ids, + attention_mask=input_mask, + protein_input_ids=protein_input_ids, + protein_attention_mask=protein_input_mask, + max_length=self.seq_length + 2, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, text_input_ids, text_input_mask, protein_input_ids, protein_input_mask) = config_and_inputs + inputs_dict = { + "input_ids": text_input_ids, + "attention_mask": text_input_mask, + "protein_input_ids": protein_input_ids, + "protein_attention_mask": protein_input_mask, + } + return config, inputs_dict + + +@require_torch +class EvollaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (EvollaModel, EvollaForProteinText2Text) if is_torch_available() else () + pipeline_model_mapping = {"feature-extraction": EvollaModel} if is_torch_available() else {} + test_pruning = False + test_headmasking = False + test_torchscript = False + test_resize_embeddings = False + maxDiff = None + + def setUp(self): + self.model_tester = EvollaModelTester(self) + self.config_tester = ConfigTester(self, config_class=EvollaConfig, hidden_size=37) + + @property + def is_encoder_decoder(self): + return self.model_tester.is_encoder_decoder + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + # XXX: EvollaForProteinText2Text has no MODEL_FOR group yet, but it should be the same + # as MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, so for now manually changing to do the right thing + # as super won't do it + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + + return inputs_dict + + def test_model_outputs_equivalence(self): + try: + orig = self.all_model_classes + # EvollaModel.forward doesn't have labels input arg - only EvollaForProteinText2Text does + self.all_model_classes = (EvollaForProteinText2Text,) if is_torch_available() else () + super().test_model_outputs_equivalence() + finally: + self.all_model_classes = orig + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model_single_protein(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(num_proteins=1) + self.model_tester.create_and_check_model(*config_and_inputs, batch_size=1) + + def test_model_multiple_proteins(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(num_proteins=2) + self.model_tester.create_and_check_model(*config_and_inputs, batch_size=2) + + def test_generate_single_protein(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(num_proteins=1) + self.model_tester.create_and_check_model_gen(*config_and_inputs) + + def test_generate_multiple_proteins(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(num_proteins=2) + self.model_tester.create_and_check_model_gen(*config_and_inputs) + + def test_saprot_output(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + protein_informations = { + "input_ids": inputs_dict["protein_input_ids"], + "attention_mask": inputs_dict["protein_attention_mask"], + } + for model_class in self.all_model_classes: + if model_class is not EvollaModel: + continue + model = model_class(config) + model.to(torch_device) + model.eval() + protein_encoder_outputs = model.protein_encoder.model(**protein_informations, return_dict=True) + print(model_class, protein_encoder_outputs) + + def test_protein_encoder_output(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + protein_informations = { + "input_ids": inputs_dict["protein_input_ids"], + "attention_mask": inputs_dict["protein_attention_mask"], + } + for model_class in self.all_model_classes: + if model_class is not EvollaModel: + continue + model = model_class(config) + model.to(torch_device) + model.eval() + protein_encoder_outputs = model.protein_encoder(**protein_informations, return_dict=True) + print(model_class, protein_encoder_outputs) + + def test_single_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + print(outputs) + + def test_initialization(self): + # we skip the latents initialization test + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + # skip latents + if name.endswith("latents"): + print(f"Skipping latents {name}") + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + @unittest.skip("Evolla requires both text and protein inputs which is currently not done in this test.") + def test_eager_matches_sdpa_inference(self): + pass + + @unittest.skip("Evolla does not support eager attention implementation.") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip( + "Evolla has a separate test runner for generation tests with complex inheritance, causing this check to fail." + ) + def test_generation_tester_mixin_inheritance(self): + pass + + @unittest.skip("Evolla requires both text and protein inputs which is currently not done in this test.") + def test_flex_attention_with_grads(self): + pass + + +@require_torch +class EvollaModelIntegrationTest(TestCasePlus): + def _prepare_for_inputs(self): + aa_seq = "MLLEETLKSCPIVKRGKYHYFIHPISDGVPLVEPKLLREVATRIIKIGNFEGVNKIVTAEAMGIPLVTTLSLYTDIPYVIMRKREYKLPGEVPVFQSTGYSKGQLYLNGIEKGDKVIIIDDVISTGGTMIAIINALERAGAEIKDIICVIERGDGKKIVEEKTGYKIKTLVKIDVVDGEVVIL" + foldseek = "dvvvvqqqpfawdddppdtdgcgclapvpdpddpvvlvvllvlcvvpadpvqaqeeeeeddscpsnvvsncvvpvhyydywylddppdppkdwqwf######gitidpdqaaaheyeyeeaeqdqlrvvlsvvvrcvvrnyhhrayeyaeyhycnqvvccvvpvghyhynwywdqdpsgidtd" + question = "What is the function of this protein?" + + protein_information = { + "aa_seq": aa_seq, + "foldseek": foldseek, + } + messages = [ + {"role": "system", "content": "You are an AI expert that can answer any questions about protein."}, + {"role": "user", "content": question}, + ] + return protein_information, messages + + @cached_property + def default_processor(self): + return EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf", revision="refs/pr/11") + + @require_bitsandbytes + @slow + def test_inference_natural_language_protein_reasoning(self): + protein_information, messages = self._prepare_for_inputs() + processor = self.default_processor + inputs = processor( + messages_list=[messages], proteins=[protein_information], return_tensors="pt", padding="longest" + ).to(torch_device) + + # the CI gpu is small so using quantization to fit + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype="float16", + ) + model = EvollaForProteinText2Text.from_pretrained( + "westlake-repl/Evolla-10B-hf", + quantization_config=quantization_config, + device_map="auto", + ) + generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + + # keep for debugging + for i, t in enumerate(generated_text): + t = bytes(t, "utf-8").decode("unicode_escape") + print(f"{i}:\n{t}\n") + + self.assertIn("This protein", generated_text[0]) + + self.assertIn("purine", generated_text[0]) diff --git a/tests/models/evolla/test_processor_evolla.py b/tests/models/evolla/test_processor_evolla.py new file mode 100644 index 0000000000..0a1f1f3cd2 --- /dev/null +++ b/tests/models/evolla/test_processor_evolla.py @@ -0,0 +1,295 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import shutil +import tempfile +import unittest + +from transformers import ( + AutoProcessor, + EvollaProcessor, +) +from transformers.testing_utils import require_torch +from transformers.utils import is_torch_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_torch_available(): + import torch + + +EVOLLA_VALID_AA = list("ACDEFGHIKLMNPQRSTVWY#") +EVOLLA_VALID_FS = list("pynwrqhgdlvtmfsaeikc#") + + +@require_torch +class EvollaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = EvollaProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + processor = EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf") + + processor.save_pretrained(self.tmpdirname) + + self.input_keys = ["protein_input_ids", "protein_attention_mask", "input_ids", "attention_mask"] + + def prepare_input_and_expected_output(self): + amino_acid_sequence = "AAAA" + foldseek_sequence = "dddd" + question = "What is the function of this protein?" + + expected_output = { + "protein_input_ids": torch.tensor([[0, 13, 13, 13, 13, 2]]), + "protein_attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1]]), + "input_ids": torch.tensor( + [ + [ + 128000, + 128006, + 9125, + 128007, + 271, + 2675, + 527, + 459, + 15592, + 6335, + 430, + 649, + 4320, + 904, + 4860, + 922, + 13128, + 13, + 128009, + 128006, + 882, + 128007, + 271, + 3923, + 374, + 279, + 734, + 315, + 420, + 13128, + 30, + 128009, + 128006, + 78191, + 128007, + 271, + ] + ] + ), + "attention_mask": torch.tensor( + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ] + ), + } + protein_dict = {"aa_seq": amino_acid_sequence, "foldseek": foldseek_sequence} + message = [ + {"role": "system", "content": "You are an AI expert that can answer any questions about protein."}, + {"role": "user", "content": question}, + ] + return protein_dict, message, expected_output + + def test_processor(self): + protein_tokenizer = self.get_protein_tokenizer() + tokenizer = self.get_tokenizer() + + processor = EvollaProcessor(protein_tokenizer, tokenizer) + + protein_dict, message, expected_output = self.prepare_input_and_expected_output() + inputs = processor(proteins=[protein_dict], messages_list=[message]) + + # check if the input is correct + for key, value in expected_output.items(): + self.assertTrue( + torch.equal(inputs[key], value), + f"inputs[key] is {inputs[key]} and expected_output[key] is {expected_output[key]}", + ) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_protein_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).protein_tokenizer + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_inputs_single(self): + proteins = { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=100)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=100)), + } + return proteins + + def prepare_inputs_pair(self): + proteins = [ + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=100)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=100)), + }, + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=100)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=100)), + }, + ] + return proteins + + def prepare_inputs_long(self): + proteins = [ + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=100)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=100)), + }, + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=2000)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=2000)), + }, + ] + return proteins + + def prepare_inputs_short(self): + proteins = [ + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=1)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=1)), + }, + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=100)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=100)), + }, + ] + return proteins + + def prepare_inputs_empty(self): + proteins = [ + { + "aa_seq": "", + "foldseek": "", + }, + { + "aa_seq": "".join(random.choices(EVOLLA_VALID_AA, k=100)), + "foldseek": "".join(random.choices(EVOLLA_VALID_FS, k=100)), + }, + ] + return proteins + + def prepare_inputs(self, protein_types="pair"): + r""" + Prepare inputs for the test. + + Args: + protein_types (`str`): the types of proteins to prepare. + - "single": a single correct protein. + - "pair": a pair of correct proteins. + - "long": a long sequence of correct proteins and a correct protein. + - "short": a short sequence of correct proteins (only have 1 aa) and a correct protein. + - "empty": an empty sequence of proteins and a correct protein. + """ + if protein_types == "single": + proteins = self.prepare_inputs_single() + elif protein_types == "pair": + proteins = self.prepare_inputs_pair() + elif protein_types == "long": + proteins = self.prepare_inputs_long() + elif protein_types == "short": + proteins = self.prepare_inputs_short() + elif protein_types == "empty": + proteins = self.prepare_inputs_empty() + else: + raise ValueError( + f"protein_types should be one of 'single', 'pair', 'long','short', 'empty', but got {protein_types}" + ) + + questions = ["What is the function of the protein?"] * len(proteins) + messages_list = [] + for question in questions: + messages = [ + {"role": "system", "content": "You are an AI expert that can answer any questions about protein."}, + {"role": "user", "content": question}, + ] + messages_list.append(messages) + return proteins, messages_list + + def test_tokenizer_decode(self): + protein_tokenizer = self.get_protein_tokenizer() + tokenizer = self.get_tokenizer() + + processor = EvollaProcessor(tokenizer=tokenizer, protein_tokenizer=protein_tokenizer, return_tensors="pt") + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + protein_tokenizer = self.get_protein_tokenizer() + tokenizer = self.get_tokenizer() + + processor = EvollaProcessor(tokenizer=tokenizer, protein_tokenizer=protein_tokenizer) + proteins, messages_list = self.prepare_inputs() + + inputs = processor(messages_list=messages_list, proteins=proteins, padding="longest", return_tensors="pt") + + # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask'] + self.assertSetEqual(set(inputs.keys()), set(self.input_keys)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 5c79e0a228..01ed84939f 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -92,6 +92,7 @@ PRIVATE_MODELS = [ "Phi4MultimodalAudioModel", "Phi4MultimodalVisionModel", "Glm4vVisionModel", + "EvollaSaProtPreTrainedModel", ] # Update this list for models that are not tested with a comment explaining the reason it should not be.