@@ -2166,11 +2166,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
offload_state_dict=False,
|
offload_state_dict=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
if device_map is not None and "disk" in device_map.values() and offload_folder is None:
|
if device_map is not None and "disk" in device_map.values():
|
||||||
raise ValueError(
|
if offload_folder is None:
|
||||||
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for"
|
raise ValueError(
|
||||||
" them."
|
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
||||||
)
|
" for them."
|
||||||
|
)
|
||||||
|
os.makedirs(offload_folder, exist_ok=True)
|
||||||
# Retrieve missing & unexpected_keys
|
# Retrieve missing & unexpected_keys
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
@@ -2344,6 +2346,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if offload_index is not None and len(offload_index) > 0:
|
if offload_index is not None and len(offload_index) > 0:
|
||||||
|
if model != model_to_load:
|
||||||
|
# We need to add the prefix of the base model
|
||||||
|
prefix = cls.base_model_prefix
|
||||||
|
for weight_name in offload_index:
|
||||||
|
shutil.move(
|
||||||
|
os.path.join(offload_folder, f"{weight_name}.dat"),
|
||||||
|
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
|
||||||
|
)
|
||||||
|
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
|
||||||
save_offload_index(offload_index, offload_folder)
|
save_offload_index(offload_index, offload_folder)
|
||||||
|
|
||||||
if offload_state_dict:
|
if offload_state_dict:
|
||||||
|
|||||||
@@ -2811,6 +2811,48 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
text_output = tokenizer.decode(output[0].tolist())
|
text_output = tokenizer.decode(output[0].tolist())
|
||||||
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_from_pretrained_disk_offload_task_model(self):
|
||||||
|
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
device_map = {
|
||||||
|
"transformer.wte": 0,
|
||||||
|
"transformer.wpe": 0,
|
||||||
|
"transformer.h.0": "cpu",
|
||||||
|
"transformer.h.1": "cpu",
|
||||||
|
"transformer.h.2": "cpu",
|
||||||
|
"transformer.h.3": "disk",
|
||||||
|
"transformer.h.4": "disk",
|
||||||
|
"transformer.ln_f": 0,
|
||||||
|
"lm_head": 0,
|
||||||
|
}
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
inputs = torch.tensor([[1, 2, 3]]).to(0)
|
||||||
|
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
|
||||||
|
outputs1 = new_model.to(0)(inputs)
|
||||||
|
|
||||||
|
offload_folder = os.path.join(tmp_dir, "offload")
|
||||||
|
new_model_with_offload = AutoModelForCausalLM.from_pretrained(
|
||||||
|
tmp_dir, device_map=device_map, offload_folder=offload_folder
|
||||||
|
)
|
||||||
|
outputs2 = new_model_with_offload(inputs)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
||||||
|
|
||||||
|
# With state dict temp offload
|
||||||
|
offload_folder = os.path.join(tmp_dir, "offload")
|
||||||
|
new_model_with_offload = AutoModelForCausalLM.from_pretrained(
|
||||||
|
tmp_dir,
|
||||||
|
device_map=device_map,
|
||||||
|
offload_folder=offload_folder,
|
||||||
|
offload_state_dict=True,
|
||||||
|
)
|
||||||
|
outputs2 = new_model_with_offload(inputs)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
||||||
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
|
|||||||
Reference in New Issue
Block a user