From 31484afbed45ee589f8e4e247b10188a09399734 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 25 May 2022 10:51:27 -0400 Subject: [PATCH] Add test for new model parallelism features (#17401) --- src/transformers/modeling_utils.py | 8 +- src/transformers/models/t5/modeling_t5.py | 14 ++-- tests/test_modeling_common.py | 90 +++++++++++++++++++++++ 3 files changed, 103 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8ade42459d..3327539508 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1734,6 +1734,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix same device. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. + max_memory (`Dict`, *optional*): + A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available + for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. 
offload_state_dict (`bool`, *optional*, defaults to `False`): @@ -1822,6 +1825,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix torch_dtype = kwargs.pop("torch_dtype", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None) device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) @@ -2119,7 +2123,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if model._no_split_modules is None: raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.") no_split_modules = model._no_split_modules - device_map = infer_auto_device_map(model, no_split_module_classes=no_split_modules, dtype=torch_dtype) + device_map = infer_auto_device_map( + model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory + ) if from_tf: if resolved_archive_file.endswith(".index"): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 41ff0ecf04..92b64ea7fb 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -420,14 +420,12 @@ class T5Attention(nn.Module): relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length): + def compute_bias(self, query_length, key_length, device=None): """Compute binned relative position bias""" - context_position = torch.arange( - query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device - )[:, None] - memory_position = torch.arange( - key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device - )[None, :] + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, 
device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) @@ -522,7 +520,7 @@ class T5Attention(nn.Module): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length) + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4540ebe962..9a7777af4a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -51,7 +51,9 @@ from transformers.testing_utils import ( is_pt_flax_cross_test, is_pt_tf_cross_test, is_staging_test, + require_accelerate, require_torch, + require_torch_gpu, require_torch_multi_gpu, require_usr_bin_time, slow, @@ -60,6 +62,7 @@ from transformers.testing_utils import ( from transformers.utils import ( WEIGHTS_INDEX_NAME, WEIGHTS_NAME, + is_accelerate_available, is_flax_available, is_tf_available, is_torch_fx_available, @@ -72,6 +75,10 @@ sys.path.append(str(Path(__file__).parent.parent / "utils")) from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402 +if is_accelerate_available(): + from accelerate.utils import compute_module_sizes + + if is_torch_available(): import torch from torch import nn @@ -2178,6 +2185,86 @@ class ModelTesterMixin: model.parallelize() model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2) + def check_device_map_is_respected(self, model, device_map): + for param_name, param in model.named_parameters(): + # Find device in device_map + while len(param_name) > 0 and param_name not in device_map: + param_name = 
".".join(param_name.split(".")[:-1]) + if param_name not in device_map: + raise ValueError("device map is incomplete, it does not contain any device for `param_name`.") + + param_device = device_map[param_name] + if param_device in ["cpu", "disk"]: + self.assertEqual(param.device, torch.device("meta")) + else: + self.assertEqual(param.device, torch.device(param_device)) + + @require_accelerate + @require_torch_gpu + def test_cpu_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.num_hidden_layers < 5: + config.num_hidden_layers = 5 + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. + max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + new_output = new_model(**inputs_dict) + + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_multi_gpu + def test_model_parallelism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.num_hidden_layers < 5: + config.num_hidden_layers = 5 + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict = 
self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. + max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure the model will actually end up split across both GPUs + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + new_output = new_model(**inputs_dict) + + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + def test_problem_types(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -2547,6 +2634,7 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(model.parameters(), ref_model.parameters()): self.assertTrue(torch.allclose(p1, p2)) + @require_accelerate def test_from_pretrained_low_cpu_mem_usage_functional(self): # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and # sharded models @@ -2559,6 +2647,7 @@ class ModelUtilsTest(TestCasePlus): _ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True) @require_usr_bin_time + @require_accelerate def test_from_pretrained_low_cpu_mem_usage_measured(self): # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default @@ -2597,6 +2686,7 @@ class ModelUtilsTest(TestCasePlus): # functionality to load models directly on gpu, this test can be rewritten to use torch's # cuda memory tracking and then we should be able to do a much more precise test. 
+ @require_accelerate @require_torch_multi_gpu @slow def test_model_parallelism_gpt2(self):