[Mamba doc] Post merge updates (#29472)
* post merge update * nit * oups
This commit is contained in:
@@ -44,11 +44,8 @@ The original code can be found [here](https://github.com/state-spaces/mamba).
|
||||
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
|
||||
model.config.use_cache = True
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
@@ -63,8 +60,8 @@ from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
model_id = "ArthurZ/mamba-2.8b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token ="<s>")
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
@@ -77,7 +74,7 @@ training_args = TrainingArguments(
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules="all-linear",
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user