Add test for new model parallelism features (#17401)

This commit is contained in:
Sylvain Gugger
2022-05-25 10:51:27 -04:00
committed by GitHub
parent 56b35ce3eb
commit 31484afbed
3 changed files with 103 additions and 9 deletions

View File

@@ -51,7 +51,9 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pt_tf_cross_test,
is_staging_test,
require_accelerate,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_usr_bin_time,
slow,
@@ -60,6 +62,7 @@ from transformers.testing_utils import (
from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_flax_available,
is_tf_available,
is_torch_fx_available,
@@ -72,6 +75,10 @@ sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
if is_accelerate_available():
from accelerate.utils import compute_module_sizes
if is_torch_available():
import torch
from torch import nn
@@ -2178,6 +2185,86 @@ class ModelTesterMixin:
model.parallelize()
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
def check_device_map_is_respected(self, model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
self.assertEqual(param.device, torch.device("meta"))
else:
self.assertEqual(param.device, torch.device(param_device))
@require_accelerate
@require_torch_gpu
def test_cpu_offload(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.num_hidden_layers < 5:
config.num_hidden_layers = 5
for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@require_accelerate
@require_torch_multi_gpu
def test_model_parallelism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.num_hidden_layers < 5:
config.num_hidden_layers = 5
for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
def test_problem_types(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -2547,6 +2634,7 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_accelerate
def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
# sharded models
@@ -2559,6 +2647,7 @@ class ModelUtilsTest(TestCasePlus):
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
@require_usr_bin_time
@require_accelerate
def test_from_pretrained_low_cpu_mem_usage_measured(self):
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
@@ -2597,6 +2686,7 @@ class ModelUtilsTest(TestCasePlus):
# functionality to load models directly on gpu, this test can be rewritten to use torch's
# cuda memory tracking and then we should be able to do a much more precise test.
@require_accelerate
@require_torch_multi_gpu
@slow
def test_model_parallelism_gpt2(self):