@@ -39,11 +39,11 @@ The original code can be found [here](https://github.com/state-spaces/mamba).
|
|||||||
|
|
||||||
### A simple generation example:
|
### A simple generation example:
|
||||||
```python
|
```python
|
||||||
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
|
from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer
|
||||||
import torch
|
import torch
|
||||||
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
|
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
|
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
|
||||||
model = MambaForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
|
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
|
||||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||||
|
|
||||||
out = model.generate(input_ids, max_new_tokens=10)
|
out = model.generate(input_ids, max_new_tokens=10)
|
||||||
|
|||||||
Reference in New Issue
Block a user