diff --git a/docs/source/en/main_classes/quantization.mdx b/docs/source/en/main_classes/quantization.mdx index 6ab6ec9dfa..37877c9d02 100644 --- a/docs/source/en/main_classes/quantization.mdx +++ b/docs/source/en/main_classes/quantization.mdx @@ -33,7 +33,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer model_id = "bigscience/bloom-1b7" tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id, device_map == "auto", load_in_8bit=True) +model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True) ``` Then, use your model as you would usually use a [`PreTrainedModel`].