@@ -54,8 +54,8 @@ For each model type, there is a separate class for each machine learning framewo
|
|||||||
from transformers import AutoModelForCausalLM, MistralForCausalLM
|
from transformers import AutoModelForCausalLM, MistralForCausalLM
|
||||||
|
|
||||||
# load with AutoClass or model-specific class
|
# load with AutoClass or model-specific class
|
||||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
|
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
|
||||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
|
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
|
||||||
```
|
```
|
||||||
|
|
||||||
</hfoption>
|
</hfoption>
|
||||||
@@ -272,6 +272,7 @@ Explicitly set the [torch_dtype](https://pytorch.org/docs/stable/tensor_attribut
|
|||||||
<hfoption id="specific dtype">
|
<hfoption id="specific dtype">
|
||||||
|
|
||||||
```py
|
```py
|
||||||
|
import torch
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
|
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
|
||||||
|
|||||||
Reference in New Issue
Block a user