Add test for new model parallelism features (#17401)
This commit is contained in:
@@ -1734,6 +1734,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
same device.
|
same device.
|
||||||
|
|
||||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
|
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
|
||||||
|
max_memory (`Dict`, *optional*):
|
||||||
|
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
||||||
|
GPU and the available CPU RAM if unset.
|
||||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||||
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
||||||
@@ -1822,6 +1825,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
|
||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
|
max_memory = kwargs.pop("max_memory", None)
|
||||||
offload_folder = kwargs.pop("offload_folder", None)
|
offload_folder = kwargs.pop("offload_folder", None)
|
||||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||||
|
|
||||||
@@ -2119,7 +2123,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if model._no_split_modules is None:
|
if model._no_split_modules is None:
|
||||||
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
|
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
|
||||||
no_split_modules = model._no_split_modules
|
no_split_modules = model._no_split_modules
|
||||||
device_map = infer_auto_device_map(model, no_split_module_classes=no_split_modules, dtype=torch_dtype)
|
device_map = infer_auto_device_map(
|
||||||
|
model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
|
||||||
|
)
|
||||||
|
|
||||||
if from_tf:
|
if from_tf:
|
||||||
if resolved_archive_file.endswith(".index"):
|
if resolved_archive_file.endswith(".index"):
|
||||||
|
|||||||
@@ -420,14 +420,12 @@ class T5Attention(nn.Module):
|
|||||||
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||||
return relative_buckets
|
return relative_buckets
|
||||||
|
|
||||||
def compute_bias(self, query_length, key_length):
|
def compute_bias(self, query_length, key_length, device=None):
|
||||||
"""Compute binned relative position bias"""
|
"""Compute binned relative position bias"""
|
||||||
context_position = torch.arange(
|
if device is None:
|
||||||
query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
|
device = self.relative_attention_bias.weight.device
|
||||||
)[:, None]
|
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||||
memory_position = torch.arange(
|
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||||
key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
|
|
||||||
)[None, :]
|
|
||||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||||
relative_position_bucket = self._relative_position_bucket(
|
relative_position_bucket = self._relative_position_bucket(
|
||||||
relative_position, # shape (query_length, key_length)
|
relative_position, # shape (query_length, key_length)
|
||||||
@@ -522,7 +520,7 @@ class T5Attention(nn.Module):
|
|||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
position_bias.requires_grad = True
|
position_bias.requires_grad = True
|
||||||
else:
|
else:
|
||||||
position_bias = self.compute_bias(real_seq_length, key_length)
|
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
|
||||||
|
|
||||||
# if key and values are already calculated
|
# if key and values are already calculated
|
||||||
# we want only the last query position bias
|
# we want only the last query position bias
|
||||||
|
|||||||
@@ -51,7 +51,9 @@ from transformers.testing_utils import (
|
|||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
|
require_accelerate,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_usr_bin_time,
|
require_usr_bin_time,
|
||||||
slow,
|
slow,
|
||||||
@@ -60,6 +62,7 @@ from transformers.testing_utils import (
|
|||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
is_accelerate_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_fx_available,
|
is_torch_fx_available,
|
||||||
@@ -72,6 +75,10 @@ sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
|||||||
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
||||||
|
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
from accelerate.utils import compute_module_sizes
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -2178,6 +2185,86 @@ class ModelTesterMixin:
|
|||||||
model.parallelize()
|
model.parallelize()
|
||||||
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||||
|
|
||||||
|
def check_device_map_is_respected(self, model, device_map):
|
||||||
|
for param_name, param in model.named_parameters():
|
||||||
|
# Find device in device_map
|
||||||
|
while len(param_name) > 0 and param_name not in device_map:
|
||||||
|
param_name = ".".join(param_name.split(".")[:-1])
|
||||||
|
if param_name not in device_map:
|
||||||
|
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||||
|
|
||||||
|
param_device = device_map[param_name]
|
||||||
|
if param_device in ["cpu", "disk"]:
|
||||||
|
self.assertEqual(param.device, torch.device("meta"))
|
||||||
|
else:
|
||||||
|
self.assertEqual(param.device, torch.device(param_device))
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_cpu_offload(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if config.num_hidden_layers < 5:
|
||||||
|
config.num_hidden_layers = 5
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class._no_split_modules is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config).eval()
|
||||||
|
model = model.to(torch_device)
|
||||||
|
base_output = model(**inputs_dict)
|
||||||
|
|
||||||
|
model_size = compute_module_sizes(model)[""]
|
||||||
|
# We test several splits of sizes to make sure it works.
|
||||||
|
max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.cpu().save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for max_size in max_gpu_sizes:
|
||||||
|
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||||
|
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||||
|
# Making sure part of the model will actually end up offloaded
|
||||||
|
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
|
||||||
|
|
||||||
|
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
new_output = new_model(**inputs_dict)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_model_parallelism(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if config.num_hidden_layers < 5:
|
||||||
|
config.num_hidden_layers = 5
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class._no_split_modules is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config).eval()
|
||||||
|
model = model.to(torch_device)
|
||||||
|
base_output = model(**inputs_dict)
|
||||||
|
|
||||||
|
model_size = compute_module_sizes(model)[""]
|
||||||
|
# We test several splits of sizes to make sure it works.
|
||||||
|
max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.cpu().save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for max_size in max_gpu_sizes:
|
||||||
|
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||||
|
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||||
|
# Making sure part of the model will actually end up offloaded
|
||||||
|
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
|
||||||
|
|
||||||
|
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
new_output = new_model(**inputs_dict)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||||
|
|
||||||
def test_problem_types(self):
|
def test_problem_types(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -2547,6 +2634,7 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
||||||
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
||||||
# sharded models
|
# sharded models
|
||||||
@@ -2559,6 +2647,7 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
|
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
|
||||||
|
|
||||||
@require_usr_bin_time
|
@require_usr_bin_time
|
||||||
|
@require_accelerate
|
||||||
def test_from_pretrained_low_cpu_mem_usage_measured(self):
|
def test_from_pretrained_low_cpu_mem_usage_measured(self):
|
||||||
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
|
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
|
||||||
|
|
||||||
@@ -2597,6 +2686,7 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
# functionality to load models directly on gpu, this test can be rewritten to use torch's
|
# functionality to load models directly on gpu, this test can be rewritten to use torch's
|
||||||
# cuda memory tracking and then we should be able to do a much more precise test.
|
# cuda memory tracking and then we should be able to do a much more precise test.
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@slow
|
@slow
|
||||||
def test_model_parallelism_gpt2(self):
|
def test_model_parallelism_gpt2(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user