Fix serialization for offloaded model (#31727)

* Fix serialization

* style

* add test
This commit is contained in:
Marc Sun
2024-07-05 08:07:07 +02:00
committed by GitHub
parent eaa5f41439
commit 8c5c180de0
2 changed files with 29 additions and 12 deletions

View File

@@ -1065,6 +1065,23 @@ class ModelUtilsTest(TestCasePlus):
# This checks that we did call the fake head request
mock_head.assert_called()
@require_accelerate
@mark.accelerate_tests
def test_save_model_with_device_map_cpu(self):
    """Round-trip a model loaded with ``device_map="cpu"`` through a sharded save.

    Loads the tiny GPT-2 checkpoint on CPU, records its output for a fixed
    input, saves it with a small ``max_shard_size``, reloads it from the
    saved directory, and asserts that the reloaded model yields the same
    output.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    input_ids = torch.tensor([[1, 2, 3]])

    with tempfile.TemporaryDirectory() as save_dir:
        reference_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="cpu")
        expected = reference_model(input_ids)[0]

        # Per the original note the model is about 1.6MB, so a 200KB cap is
        # presumably meant to split the checkpoint into several shards.
        reference_model.save_pretrained(save_dir, max_shard_size="200KB")

        reloaded_model = AutoModelForCausalLM.from_pretrained(save_dir, device_map="cpu")
        actual = reloaded_model(input_ids)[0]

        outputs_match = torch.allclose(expected, actual)
        self.assertTrue(outputs_match)
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator
@@ -1083,9 +1100,9 @@ class ModelUtilsTest(TestCasePlus):
# check_models_equal requires onloaded tensors
model_id = "hf-internal-testing/tiny-random-gpt2"
onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(f"{torch_device}:0")
inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
cpu_output = onloaded_model(inputs)[0]
output = onloaded_model(inputs)[0]
with tempfile.TemporaryDirectory() as tmp_dir:
offload_folder = os.path.join(tmp_dir, "offload")
@@ -1099,7 +1116,7 @@ class ModelUtilsTest(TestCasePlus):
saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
postsaved_output = saved_model(inputs)[0]
self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
self.assertTrue(torch.allclose(output, presaved_output, atol=1e-4))
self.assertTrue(torch.allclose(presaved_output, postsaved_output))
@require_safetensors