From 1cd9be2aeb2a3cf0d8e982717c9f1a63319f838c Mon Sep 17 00:00:00 2001 From: alexorona Date: Mon, 23 Nov 2020 11:41:23 -0800 Subject: [PATCH] gpt2 and t5 parallel modeling (#8696) * gpt2 and t5 parallel modeling * model_parallel utils update * adding missing model_parallel_utils Adds missing model_parallel_utils and reverses the changes to code in modeling_gpt2 and modeling_t5 * training_args reformat Reformatted training_args * style formatting Style formatting doc string length on training_args and model_parallel_utils * style changes make style && make quality for training_args and model_parallel_utils. * adding tests * minor change in trainer reverts loss calculation * Update training_args.py * Update training_args.py added back docstring language for adam_beta1 and adam_beta2 * Update trainer.py * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix style & rebase Co-authored-by: Lysandre Debut Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: LysandreJik --- src/transformers/models/gpt2/modeling_gpt2.py | 122 +++++++++++ src/transformers/models/t5/modeling_t5.py | 189 +++++++++++++++++- src/transformers/trainer.py | 16 +- src/transformers/training_args.py | 24 ++- .../utils/model_parallel_utils.py | 40 ++++ tests/test_modeling_common.py | 92 +++++++++ tests/test_modeling_gpt2.py | 5 + tests/test_modeling_t5.py | 12 ++ 8 files changed, 492 insertions(+), 8 deletions(-) create mode 100644 src/transformers/utils/model_parallel_utils.py diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 759c275b74..39b40a1e54 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -44,6 +44,7 @@ from ...modeling_utils import ( prune_conv1d_layer, ) from ...utils import logging +from ...utils.model_parallel_utils import assert_device_map, get_device_map from 
.configuration_gpt2 import GPT2Config @@ -474,6 +475,46 @@ GPT2_INPUTS_DOCSTRING = r""" return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ +PARALLELIZE_DOCSTRING = r""" + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + + Example:: + Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained('gpt2-xl') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. 
+ + Example:: + On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained('gpt2-large') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], + + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" @add_start_docstrings( @@ -491,6 +532,42 @@ class GPT2Model(GPT2PreTrainedModel): self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() def get_input_embeddings(self): return self.wte @@ -616,6 +693,18 @@ class 
GPT2Model(GPT2PreTrainedModel): all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = layer_past.to(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) @@ -658,6 +747,12 @@ class GPT2Model(GPT2PreTrainedModel): if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3],) + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(*output_shape) @@ -694,6 +789,28 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.init_weights() + self.model_parallel = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def 
deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + def get_output_embeddings(self): return self.lm_head @@ -774,6 +891,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ) hidden_states = transformer_outputs[0] + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + lm_logits = self.lm_head(hidden_states) loss = None diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 915c9548c1..7b28d2590c 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -40,6 +40,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging +from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_t5 import T5Config @@ -177,6 +178,47 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): # - torch.nn.Module for the layers and # - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module) #################################################### +PARALLELIZE_DOCSTRING = r""" + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. 
For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" class T5LayerNorm(nn.Module): @@ -729,6 +771,42 @@ class T5Stack(T5PreTrainedModel): self.dropout = nn.Dropout(config.dropout_rate) self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens =
self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() def get_input_embeddings(self): return self.embed_tokens @@ -753,7 +831,10 @@ class T5Stack(T5PreTrainedModel): output_hidden_states=None, return_dict=None, ): - + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -821,6 +902,20 @@ class T5Stack(T5PreTrainedModel): hidden_states = self.dropout(inputs_embeds) for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) if
output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -855,6 +950,12 @@ class T5Stack(T5PreTrainedModel): if self.is_decoder: all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1008,6 +1109,32 @@ class T5Model(T5PreTrainedModel): self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + def get_input_embeddings(self): return self.shared @@ -1086,6 +1213,18 @@ class T5Model(T5PreTrainedModel): ) hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + 
attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) # Decode decoder_outputs = self.decoder( @@ -1147,6 +1286,34 @@ class T5ForConditionalGeneration(T5PreTrainedModel): self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + def get_input_embeddings(self): return self.shared @@ -1231,6 +1398,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) @@ -1244,6 +1414,17 @@ class T5ForConditionalGeneration(T5PreTrainedModel): if decoder_inputs_embeds is not None: decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if 
decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -1261,6 +1442,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): sequence_output = decoder_outputs[0] + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 64c363afb5..27372a9b50 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -241,7 +241,11 @@ class Trainer: self.hp_name = None if model is None and model_init is not None: model = self.call_model_init() - self.model = model.to(args.device) if model is not None else None + # Model parallel + if not self.args.model_parallel: + self.model = model.to(args.device) if model is not None else None + else: + self.model = model if model is not None else None default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset @@ -578,7 +582,8 @@ class Trainer: model = self.call_model_init(trial) - self.model = model.to(self.args.device) + if not self.args.model_parallel: + self.model = model.to(self.args.device) # Reinitializes optimizer and scheduler self.optimizer, 
self.lr_scheduler = None, None @@ -625,7 +630,7 @@ class Trainer: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) - if self.args.n_gpu > 1: + if self.args.n_gpu > 1 and not self.args.model_parallel: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) @@ -805,7 +810,8 @@ class Trainer: ) if isinstance(model, PreTrainedModel): self.model = model.from_pretrained(self.state.best_model_checkpoint) - self.model = self.model.to(self.args.device) + if not self.args.model_parallel: + self.model = self.model.to(self.args.device) else: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) self.model.load_state_dict(state_dict) @@ -1323,7 +1329,7 @@ class Trainer: model = self.model # multi-gpu eval - if self.args.n_gpu > 1: + if self.args.n_gpu > 1 and not self.args.model_parallel: model = torch.nn.DataParallel(model) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 416037ac43..ff22be4db4 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -40,6 +40,9 @@ class TrainingArguments: Using :class:`~transformers.HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command line. + + + Parameters: output_dir (:obj:`str`): The output directory where the model predictions and checkpoints will be written. 
@@ -201,6 +204,15 @@ class TrainingArguments: do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) + model_parallel: bool = field( + default=False, + metadata={ + "help": ( + "If there are more than one devices, whether to use model parallelism to distribute the " + "model's modules across devices." + ) + }, + ) evaluation_strategy: EvaluationStrategy = field( default="no", metadata={"help": "Run evaluation during training at each logging step."}, @@ -366,7 +378,11 @@ class TrainingArguments: "version. Using `--per_device_train_batch_size` is preferred." ) per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size - return per_device_batch_size * max(1, self.n_gpu) + if not self.model_parallel: + train_batch_size = per_device_batch_size * max(1, self.n_gpu) + else: + train_batch_size = per_device_batch_size + return train_batch_size @property def eval_batch_size(self) -> int: @@ -379,7 +395,11 @@ class TrainingArguments: "version. Using `--per_device_eval_batch_size` is preferred." 
) per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size - return per_device_batch_size * max(1, self.n_gpu) + if not self.model_parallel: + eval_batch_size = per_device_batch_size * max(1, self.n_gpu) + else: + eval_batch_size = per_device_batch_size + return eval_batch_size @cached_property @torch_required diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py new file mode 100644 index 0000000000..6c3a6dcc1d --- /dev/null +++ b/src/transformers/utils/model_parallel_utils.py @@ -0,0 +1,40 @@ +# coding=utf-8 +from math import ceil + + +def assert_device_map(device_map, num_blocks): + blocks = list(range(0, num_blocks)) + + device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist] + + # Duplicate check + duplicate_blocks = [] + for i in device_map_blocks: + if device_map_blocks.count(i) > 1 and i not in duplicate_blocks: + duplicate_blocks.append(i) + # Missing blocks + missing_blocks = [i for i in blocks if i not in device_map_blocks] + extra_blocks = [i for i in device_map_blocks if i not in blocks] + + assert len(duplicate_blocks) == 0, ( + "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These " + "attention blocks were specified more than once: " + str(duplicate_blocks) + ) + assert len(missing_blocks) == 0, ( + "There are attention blocks for this model that are not specified in the device_map. Add these attention " + "blocks to a device on the device_map: " + str(missing_blocks) + ) + assert ( + len(extra_blocks) == 0 + ), "The device_map contains more attention blocks than this model has. 
Remove these from the device_map:" + str( + extra_blocks + ) + + +def get_device_map(n_layers, devices): + """Returns a dictionary of layers distributed evenly across all devices.""" + layers = list(range(n_layers)) + n_blocks = int(ceil(n_layers / len(devices))) + layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)) + + return dict(zip(devices, layers_list)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b72031d2f5..f4cc7d6958 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -68,6 +68,7 @@ class ModelTesterMixin: test_resize_embeddings = True test_head_masking = True test_missing_keys = True + test_model_parallel = False is_encoder_decoder = False def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): @@ -953,6 +954,97 @@ class ModelTesterMixin: with torch.no_grad(): _ = model(**self._prepare_for_class(inputs_dict, model_class)) + @require_torch_multi_gpu + def test_model_parallelization(self): + if not self.test_model_parallel: + return + + import subprocess + + def get_current_gpu_memory_use(): + run_process = subprocess.Popen( + "nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader", shell=True, stdout=subprocess.PIPE + ) + + memory_usage = run_process.stdout.read().decode("utf-8").strip() + per_device_memory = [int(memory) for memory in memory_usage.split("\n")] + return per_device_memory + + # Needs a large model to see the difference.
+ config = self.model_tester.get_large_model_config() + + for model_class in self.all_parallelizable_model_classes: + torch.cuda.empty_cache() + + # Retrieve initial memory usage (should be close to 0) + initial_memory = get_current_gpu_memory_use() + + # Put model on device + model = model_class(config) + model.to("cuda:0") + + # Retrieve the memory after the model is put on the device + memory_after_model_load = get_current_gpu_memory_use() + + del model + torch.cuda.empty_cache() + + # The memory use on that device should be higher than it was initially. + self.assertGreater(memory_after_model_load[0], initial_memory[0]) + + # Spread model layers over multiple devices + model = model_class(config) + model.parallelize() + memory_after_parallelization = get_current_gpu_memory_use() + + # Assert that the memory use on all devices is higher than it was when loaded only on CPU + for n in range(torch.cuda.device_count()): + self.assertGreater(memory_after_parallelization[n], initial_memory[n]) + + # Assert that the memory use of the first device is lower than it was when the entire model was loaded on it + self.assertLess(memory_after_parallelization[0], memory_after_model_load[0]) + + # Assert that the memory use of the second device is higher than it was when the entire model was loaded + # on the other device.
+ self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1]) + + del model + torch.cuda.empty_cache() + + @require_torch_multi_gpu + def test_model_parallel_equal_results(self): + if not self.test_model_parallel: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_parallelizable_model_classes: + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + + model = model_class(config) + output = model(**inputs_dict) + + model.parallelize() + + def cast_to_gpu(dictionary): + output = {} + for k, v in dictionary.items(): + if isinstance(v, torch.Tensor): + output[k] = v.to("cuda:0") + else: + output[k] = v + + return output + + parallel_output = model(**cast_to_gpu(inputs_dict)) + + for value, parallel_value in zip(output, parallel_output): + if isinstance(value, torch.Tensor): + self.assertTrue(torch.allclose(value, parallel_value.to("cpu"), atol=1e-7)) + elif isinstance(value, (Tuple, List)): + for value_, parallel_value_ in zip(value, parallel_value): + self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7)) + global_rng = random.Random() diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 900a989a10..298a506a5b 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -92,6 +92,9 @@ class GPT2ModelTester: self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 + def get_large_model_config(self): + return GPT2Config.from_pretrained("gpt2") + def prepare_config_and_inputs(self, gradient_checkpointing=False): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -389,7 +392,9 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () + all_parallelizable_model_classes = (GPT2LMHeadModel,) if is_torch_available() else ()
test_missing_keys = False + test_model_parallel = True # special case for DoubleHeads model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 90573d5a78..1e0292db23 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -85,6 +85,9 @@ class T5ModelTester: self.scope = None self.decoder_layers = decoder_layers + def get_large_model_config(self): + return T5Config.from_pretrained("t5-base") + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) @@ -470,9 +473,18 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () + all_parallelizable_model_classes = ( + ( + T5Model, + T5ForConditionalGeneration, + ) + if is_torch_available() + else () + ) test_pruning = False test_torchscript = True test_resize_embeddings = False + test_model_parallel = True is_encoder_decoder = True def setUp(self):