Modify device_map behavior when loading a model using from_pretrained (#23922)
* Modify device map behavior for 4/8 bits model * Remove device_map arg for training 4/8 bit model * Remove index Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Add Exceptions * Modify comment Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix formatting * Get current device with accelerate * Revert "Get current device with accelerate" This reverts commit 46f00799103bbe15bd58762ba029aab35363c4f7. * Fix Exception * Modify quantization doc * Fix error Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -429,7 +429,9 @@ class Bnb4BitTestTraining(Base4bitTest):
|
||||
return
|
||||
|
||||
# Step 1: freeze all parameters
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)
|
||||
|
||||
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False # freeze the model - train adapters later
|
||||
|
||||
@@ -684,7 +684,9 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
return
|
||||
|
||||
# Step 1: freeze all parameters
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
|
||||
|
||||
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False # freeze the model - train adapters later
|
||||
|
||||
Reference in New Issue
Block a user