Add jamba (#29943)
* Add jamba arch * apply "make fix-copies" changes * fix link to model in JambaConfig docstring * Add n_ctx in modeling file because repo-consistency wants that * Add jamba to flash attention and sdpa documentation * mamba dt_proj quant fix now works for LoRA as well * override test_left_padding_compatibility and use a more permissive tolerance. left padding numerical difference are accentuated by mamba layers * add jamba to tokenization auto * fix comments of shape (PR #24 in the model page: https://huggingface.co/ai21labs/Jamba-v0.1/discussions/24) * simple PR fixes * remove unnecessary kwargs from JambaAttentionDecoderLayer and JambaMambaDecoderLayer * remove the LoRA hack for the mamba dt_proj bias. It was solved in huggingface/peft#1530 (https://github.com/huggingface/peft/pull/1530) * Add copied comment on JambaMLP (it's the same as MixtralMLP) * remove padding_mask warnings. It's not supported anymore * fix docstring. Float instead of int * A few more minor PR fixes * (1) lowercase names for mamba layernorms (2) remove _apply_inner_layernorms and do it directly in the forward pass * Return None attention weights from mamba layers. Append to all attentions only if not None. * remove some leftover jamba archive lists * Better separation between expert vs non-expert layers. non-expert layers return None as router_logits, and it is not concatenated to all_router_logits returned from JambaModel * no need to take router_logits at config.expert_layer_offset anymore. result.router_logits now holds results only for expert layers * Add Jamba paper on READMEs * (1) rename n_ctx -> max_position_embeddings (2) don't use it in the modeling file since it's not needed (set it as an exception to check_config_attributes) * Add copied from comment * remove the code path for apply_inner_layernorms=False. Jamba always has the inner mamba layernorms * clearer docstring for _convert_to_standard_cache * style fixes * Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int). Adapt assisted decoding code tp use it. Also small change in low memory beam search decoding path to support this new int value in model_inputs * rename test so it still overrides what its meant to override * draft * oups * nit * remove more complexe logic * fix names used in config * fix fix fix * style * fix some more failing tests * generate did not init the cache 🙃 * more small nits * typo * config.mamba_expand * config.hidden_size for the intermediate size of the mamba shapes * fix init of pkv with torch.tensor() * empty tensor * fix some init issues * stupid changes required by generate because it does not even support it's own DynamicCache class * more fixes * fix general assisted gen cache_position bug * tests passing * Add offsets and periods as SPECIAL_CASES_TO_ALLOW in check_config_attributes.py * fix reorder_cache to reorder mamba states and override some more functions in HybridMambaAttentionDynamicCache * no need to override test_past_key_values_format() and _check_past_key_values_for_generate() in tests anymore * fix docstrings and typehints for past_key_values * style fixes * fix docs * change typehint due to copy from Mixtral * forgot import * import order * Add configuration_jamba and modeling_jamba to not_doctested because the model is too big to download (in docstring of JambaForCausalLM.forward) * Add integration test with tiny tandom Jamba model on hub * fix flash attention cache shapes * bring back forgotten hidden states * rename HybridMambaAttentionDynamicCache.seqlen_offset to has_previous_state (and make bool) and bugfix - it should be set to True after a finished forward pass of the entire model * align integration test after modeling fixes * bugfix - mamba can use precomputed states only of forward pass is on a single token * bugfix - mamba can use precomputed states only if they match the batch size * typo * remove making _prepare_4d_causal_attention_mask a leaf function * stop using past_seq_len.get_seq_length(). Use cache positions instead. Adjust test (test_decoder_model_past_with_large_inputs) accordingly --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -1050,7 +1050,7 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
|
||||
self.skipTest("TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
@@ -1098,6 +1098,7 @@ class GenerationTesterMixin:
|
||||
"transo_xl",
|
||||
"xlnet",
|
||||
"cpm",
|
||||
"jamba",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
@@ -1735,11 +1736,12 @@ class GenerationTesterMixin:
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Past Key Value States -- two notes here:
|
||||
# Past Key Value States -- a few notes here:
|
||||
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
|
||||
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
|
||||
# complete
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
|
||||
has_standard_cache = not any(
|
||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user