[Mamba doc] Post merge updates (#29472)
* post merge update * nit * oups
This commit is contained in:
@@ -44,11 +44,8 @@ The original code can be found [here](https://github.com/state-spaces/mamba).
|
||||
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
|
||||
model.config.use_cache = True
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
@@ -63,8 +60,8 @@ from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
model_id = "ArthurZ/mamba-2.8b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token ="<s>")
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
@@ -77,7 +74,7 @@ training_args = TrainingArguments(
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules="all-linear",
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
|
||||
@@ -53,7 +53,7 @@ is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m"
|
||||
_CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
|
||||
_CONFIG_FOR_DOC = "MambaConfig"
|
||||
|
||||
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [] # See all Mamba models at https://huggingface.co/models?filter=mamba
|
||||
@@ -605,7 +605,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
||||
def _update_model_kwargs_for_generation(
|
||||
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
model_kwargs["cache_params"] = outputs["cache_params"]
|
||||
model_kwargs["cache_params"] = outputs.get("cache_params", None)
|
||||
return model_kwargs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
|
||||
@@ -406,15 +406,15 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
@require_torch
|
||||
class MambaIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_id = "ArthurZ/mamba-2.8b"
|
||||
self.model_id = "state-spaces/mamba-2.8b-hf"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
|
||||
@parameterized.expand([(torch_device,), ("cpu",)])
|
||||
def test_simple_generate(self, device):
|
||||
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16)
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
|
||||
model.to(device)
|
||||
model.config.use_cache = True
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
|
||||
@@ -444,7 +444,7 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
expected_output = "Hello my name is John and I am a newbie to the world"
|
||||
|
||||
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(device)
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16).to(device)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=10)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
@@ -457,7 +457,7 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
expected_output = "Hello my name is\n\nI am a\n\nI am a"
|
||||
|
||||
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(device)
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-790m-hf", torch_dtype=torch.float16).to(device)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=10)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
@@ -470,7 +470,7 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a"
|
||||
|
||||
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(device)
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16).to(device)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
@@ -483,7 +483,7 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
expected_output = "Hello my name is John and I am a new member of this forum. I am a retired Marine and I am a member of the Marine Corps League. I am a"
|
||||
|
||||
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(device)
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.float16).to(device)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=30)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
|
||||
Reference in New Issue
Block a user