[Mamba doc] Post merge updates (#29472)

* post merge update

* nit

* oups
This commit is contained in:
Arthur
2024-03-11 19:46:24 +11:00
committed by GitHub
parent 0290ec19c9
commit 4f27ee936a
3 changed files with 14 additions and 17 deletions

View File

@@ -44,11 +44,8 @@ The original code can be found [here](https://github.com/state-spaces/mamba).
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer.pad_token = tokenizer.eos_token
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
model.config.use_cache = True
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
@@ -63,8 +60,8 @@ from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "ArthurZ/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token ="<s>")
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
@@ -77,7 +74,7 @@ training_args = TrainingArguments(
)
lora_config = LoraConfig(
r=8,
target_modules="all-linear",
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)