gpt2 and t5 parallel modeling (#8696)
* gpt2 and t5 parallel modeling * model_parallel utils update * adding missing model_parallel_utils Adds missing model_parallel_utils and reverses the changes to code in modeling_gpt2 and modeling_t5 * training_args reformat Reformatted training_args * style formatting Style formatting doc string length on training_args and model_parallel_utils * style changes make style && make quality for training_args and model_parallel_utils. * adding tests * minor change in trainer reverts loss calculation * Update training_args.py * Update training_args.py added back docstring language for adam_beta1 and adam_beta2 * Update trainer.py * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix style & rebase Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -44,6 +44,7 @@ from ...modeling_utils import (
|
||||
prune_conv1d_layer,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||
from .configuration_gpt2 import GPT2Config
|
||||
|
||||
|
||||
@@ -474,6 +475,46 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
PARALLELIZE_DOCSTRING = r"""
|
||||
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
||||
it will evenly distribute blocks across all devices.
|
||||
|
||||
Args:
|
||||
device_map (:obj:`Dict[int, list]`, optional, defaults to None):
|
||||
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
||||
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
||||
have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
|
||||
following number of attention modules:
|
||||
|
||||
- gpt2: 12
|
||||
- gpt2-medium: 24
|
||||
- gpt2-large: 36
|
||||
- gpt2-xl: 48
|
||||
|
||||
Example::
|
||||
Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
|
||||
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
|
||||
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||
|
||||
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
|
||||
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
|
||||
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
|
||||
model.parallelize(device_map)
|
||||
"""
|
||||
DEPARALLELIZE_DOCSTRING = r"""
|
||||
Moves the model to cpu from a model parallel state.
|
||||
|
||||
Example::
|
||||
On a 4 GPU machine with gpt2-large:
|
||||
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
|
||||
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
|
||||
|
||||
1: [8, 9, 10, 11, 12, 13, 14, 15],
|
||||
2: [16, 17, 18, 19, 20, 21, 22, 23],
|
||||
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
|
||||
model.parallelize(device_map) # Splits the model across several devices
|
||||
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -491,6 +532,42 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.init_weights()
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
|
||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    # Without an explicit map, split the blocks evenly over every visible CUDA device;
    # whichever map ends up being used is validated against the number of blocks.
    if device_map is None:
        device_map = get_device_map(len(self.h), range(torch.cuda.device_count()))
    self.device_map = device_map
    assert_device_map(self.device_map, len(self.h))
    self.model_parallel = True

    device_ids = self.device_map.keys()
    self.first_device = "cpu" if "cpu" in device_ids else "cuda:" + str(min(device_ids))
    self.last_device = "cuda:" + str(max(device_ids))

    # Token and position embeddings go on the first device.
    self.wte = self.wte.to(self.first_device)
    self.wpe = self.wpe.to(self.first_device)

    # Move every block to the device it is mapped to.
    for device_id, block_ids in self.device_map.items():
        target_device = "cuda:" + str(device_id)
        for block_id in block_ids:
            self.h[block_id] = self.h[block_id].to(target_device)

    # The final layer norm goes on the last device.
    self.ln_f = self.ln_f.to(self.last_device)
|
||||
|
||||
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    # Reset the model-parallel bookkeeping.
    self.model_parallel = False
    self.device_map = None
    self.first_device = self.last_device = "cpu"

    # Bring the embeddings, every block and the final layer norm back to the CPU.
    self.wte = self.wte.to("cpu")
    self.wpe = self.wpe.to("cpu")
    for position, block in enumerate(self.h):
        self.h[position] = block.to("cpu")
    self.ln_f = self.ln_f.to("cpu")

    # Hand the now-unused cached GPU memory back to the driver.
    torch.cuda.empty_cache()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
@@ -616,6 +693,18 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = layer_past.to(hidden_states.device)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
@@ -658,6 +747,12 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(*output_shape)
|
||||
@@ -694,6 +789,28 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
self.model_parallel = False
|
||||
|
||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    n_blocks = len(self.transformer.h)

    # Default to an even split of the transformer blocks over all visible CUDA devices.
    if device_map is None:
        device_map = get_device_map(n_blocks, range(torch.cuda.device_count()))
    self.device_map = device_map
    assert_device_map(self.device_map, n_blocks)

    # The transformer distributes its own sub-modules; the LM head is placed on the
    # transformer's first device.
    self.transformer.parallelize(self.device_map)
    self.lm_head = self.lm_head.to(self.transformer.first_device)
    self.model_parallel = True
|
||||
|
||||
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    # Let the transformer move its own sub-modules back first, then move the
    # remaining modules owned by this wrapper.
    self.transformer.deparallelize()
    self.transformer = self.transformer.to("cpu")
    self.lm_head = self.lm_head.to("cpu")

    # Reset the bookkeeping. The map set by parallelize() is cleared as well so the
    # wrapper does not keep a stale device_map once the model is back on the CPU.
    self.model_parallel = False
    self.device_map = None

    # Hand the now-unused cached GPU memory back to the driver.
    torch.cuda.empty_cache()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@@ -774,6 +891,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
# Set device for model parallelism
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.transformer.first_device)
|
||||
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
|
||||
@@ -40,6 +40,7 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import logging
|
||||
from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||
from .configuration_t5 import T5Config
|
||||
|
||||
|
||||
@@ -177,6 +178,47 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
||||
# - torch.nn.Module for the layers and
|
||||
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
|
||||
####################################################
|
||||
PARALLELIZE_DOCSTRING = r"""
|
||||
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
||||
it will evenly distribute blocks across all devices.
|
||||
|
||||
Args:
|
||||
device_map (:obj:`Dict[int, list]`, optional, defaults to None):
|
||||
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
||||
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
||||
have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
|
||||
following number of attention modules:
|
||||
|
||||
- t5-small: 6
|
||||
- t5-base: 12
|
||||
- t5-large: 24
|
||||
- t5-3b: 24
|
||||
- t5-11b: 24
|
||||
|
||||
Example::
|
||||
Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
|
||||
model = T5ForConditionalGeneration.from_pretrained('t5-3b')
|
||||
device_map = {0: [0, 1, 2],
|
||||
|
||||
1: [3, 4, 5, 6, 7, 8, 9],
|
||||
2: [10, 11, 12, 13, 14, 15, 16],
|
||||
3: [17, 18, 19, 20, 21, 22, 23]}
|
||||
model.parallelize(device_map)
|
||||
"""
|
||||
DEPARALLELIZE_DOCSTRING = r"""
|
||||
Moves the model to cpu from a model parallel state.
|
||||
|
||||
Example::
|
||||
On a 4 GPU machine with t5-3b:
|
||||
model = T5ForConditionalGeneration.from_pretrained('t5-3b')
|
||||
device_map = {0: [0, 1, 2],
|
||||
|
||||
1: [3, 4, 5, 6, 7, 8, 9],
|
||||
2: [10, 11, 12, 13, 14, 15, 16],
|
||||
3: [17, 18, 19, 20, 21, 22, 23]}
|
||||
model.parallelize(device_map) # Splits the model across several devices
|
||||
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
||||
"""
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
@@ -729,6 +771,42 @@ class T5Stack(T5PreTrainedModel):
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
self.init_weights()
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
|
||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    # Without an explicit map, split the blocks evenly over every visible CUDA device.
    # get_device_map() calls len() on its `devices` argument, so it must receive the
    # range of device ids rather than the bare device count.
    if device_map is None:
        device_map = get_device_map(len(self.block), range(torch.cuda.device_count()))
    self.device_map = device_map

    # Check validity of device_map
    assert_device_map(self.device_map, len(self.block))
    self.model_parallel = True

    device_ids = self.device_map.keys()
    self.first_device = "cpu" if "cpu" in device_ids else "cuda:" + str(min(device_ids))
    self.last_device = "cuda:" + str(max(device_ids))

    # Load every block onto the device it is mapped to.
    for device_id, layer_ids in self.device_map.items():
        cuda_device = "cuda:" + str(device_id)
        for layer_id in layer_ids:
            self.block[layer_id] = self.block[layer_id].to(cuda_device)

    # Set embed_tokens to first device
    self.embed_tokens = self.embed_tokens.to(self.first_device)
    # Set final layer norm to last device
    self.final_layer_norm = self.final_layer_norm.to(self.last_device)
|
||||
|
||||
# Documented with the de-parallelize text; the parallelize docstring was attached here by mistake.
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    # Reset the model-parallel bookkeeping.
    self.model_parallel = False
    self.device_map = None
    self.first_device = "cpu"
    self.last_device = "cpu"

    # Bring every block, the embeddings and the final layer norm back to the CPU.
    for i in range(len(self.block)):
        self.block[i] = self.block[i].to("cpu")
    self.embed_tokens = self.embed_tokens.to("cpu")
    self.final_layer_norm = self.final_layer_norm.to("cpu")

    # Hand the now-unused cached GPU memory back to the driver.
    torch.cuda.empty_cache()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
@@ -753,7 +831,10 @@ class T5Stack(T5PreTrainedModel):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.first_device)
|
||||
self.embed_tokens = self.embed_tokens.to(self.first_device)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -821,6 +902,20 @@ class T5Stack(T5PreTrainedModel):
|
||||
hidden_states = self.dropout(inputs_embeds)
|
||||
|
||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if position_bias is not None:
|
||||
position_bias = position_bias.to(hidden_states.device)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
|
||||
if encoder_extended_attention_mask is not None:
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
|
||||
if encoder_decoder_position_bias is not None:
|
||||
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
@@ -855,6 +950,12 @@ class T5Stack(T5PreTrainedModel):
|
||||
if self.is_decoder:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
@@ -1008,6 +1109,32 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
|
||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    n_blocks = len(self.encoder.block)

    # Default to an even split of the encoder's block count over all visible CUDA devices.
    if device_map is None:
        device_map = get_device_map(n_blocks, range(torch.cuda.device_count()))
    self.device_map = device_map
    assert_device_map(self.device_map, n_blocks)

    # Encoder and decoder are distributed with the same map.
    self.encoder.parallelize(self.device_map)
    self.decoder.parallelize(self.device_map)
    self.model_parallel = True
|
||||
|
||||
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    # Each stack first moves its own sub-modules back to the CPU ...
    for stack in (self.encoder, self.decoder):
        stack.deparallelize()

    # ... then the stacks themselves are moved as a whole.
    self.encoder = self.encoder.to("cpu")
    self.decoder = self.decoder.to("cpu")

    # Reset the bookkeeping and hand the now-unused cached GPU memory back to the driver.
    self.model_parallel = False
    self.device_map = None
    torch.cuda.empty_cache()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
@@ -1086,6 +1213,18 @@ class T5Model(T5PreTrainedModel):
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.decoder.first_device)
|
||||
# Set device for model parallelism
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.decoder.first_device)
|
||||
hidden_states = hidden_states.to(self.decoder.first_device)
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(self.decoder.first_device)
|
||||
if decoder_attention_mask is not None:
|
||||
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
@@ -1147,6 +1286,34 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
|
||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    n_blocks = len(self.encoder.block)

    # Default to an even split of the encoder's block count over all visible CUDA devices.
    if device_map is None:
        device_map = get_device_map(n_blocks, range(torch.cuda.device_count()))
    self.device_map = device_map
    assert_device_map(self.device_map, n_blocks)

    # Encoder and decoder are distributed with the same map; the LM head is placed on
    # the decoder's first device.
    self.encoder.parallelize(self.device_map)
    self.decoder.parallelize(self.device_map)
    self.lm_head = self.lm_head.to(self.decoder.first_device)
    self.model_parallel = True
|
||||
|
||||
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    # Each stack first moves its own sub-modules back to the CPU ...
    for stack in (self.encoder, self.decoder):
        stack.deparallelize()

    # ... then the stacks and the LM head are moved as a whole.
    self.encoder = self.encoder.to("cpu")
    self.decoder = self.decoder.to("cpu")
    self.lm_head = self.lm_head.to("cpu")

    # Reset the bookkeeping and hand the now-unused cached GPU memory back to the driver.
    self.model_parallel = False
    self.device_map = None
    torch.cuda.empty_cache()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
@@ -1231,6 +1398,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.decoder.first_device)
|
||||
|
||||
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
# get decoder inputs from shifting lm labels to the right
|
||||
decoder_input_ids = self._shift_right(labels)
|
||||
@@ -1244,6 +1414,17 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
if decoder_inputs_embeds is not None:
|
||||
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
|
||||
|
||||
# Set device for model parallelism
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.decoder.first_device)
|
||||
hidden_states = hidden_states.to(self.decoder.first_device)
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(self.decoder.first_device)
|
||||
if decoder_attention_mask is not None:
|
||||
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
@@ -1261,6 +1442,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
sequence_output = decoder_outputs[0]
|
||||
|
||||
# Set device for model parallelism
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.encoder.first_device)
|
||||
self.lm_head = self.lm_head.to(self.encoder.first_device)
|
||||
sequence_output = sequence_output.to(self.lm_head.weight.device)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
# Rescale output before projecting on vocab
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
||||
|
||||
@@ -241,7 +241,11 @@ class Trainer:
|
||||
self.hp_name = None
|
||||
if model is None and model_init is not None:
|
||||
model = self.call_model_init()
|
||||
self.model = model.to(args.device) if model is not None else None
|
||||
# Model parallel
|
||||
if not self.args.model_parallel:
|
||||
self.model = model.to(args.device) if model is not None else None
|
||||
else:
|
||||
self.model = model if model is not None else None
|
||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||
self.train_dataset = train_dataset
|
||||
@@ -578,7 +582,8 @@ class Trainer:
|
||||
|
||||
model = self.call_model_init(trial)
|
||||
|
||||
self.model = model.to(self.args.device)
|
||||
if not self.args.model_parallel:
|
||||
self.model = model.to(self.args.device)
|
||||
|
||||
# Reinitializes optimizer and scheduler
|
||||
self.optimizer, self.lr_scheduler = None, None
|
||||
@@ -625,7 +630,7 @@ class Trainer:
|
||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||
|
||||
# Multi-gpu training (should be after apex fp16 initialization)
|
||||
if self.args.n_gpu > 1:
|
||||
if self.args.n_gpu > 1 and not self.args.model_parallel:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
@@ -805,7 +810,8 @@ class Trainer:
|
||||
)
|
||||
if isinstance(model, PreTrainedModel):
|
||||
self.model = model.from_pretrained(self.state.best_model_checkpoint)
|
||||
self.model = self.model.to(self.args.device)
|
||||
if not self.args.model_parallel:
|
||||
self.model = self.model.to(self.args.device)
|
||||
else:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
@@ -1323,7 +1329,7 @@ class Trainer:
|
||||
|
||||
model = self.model
|
||||
# multi-gpu eval
|
||||
if self.args.n_gpu > 1:
|
||||
if self.args.n_gpu > 1 and not self.args.model_parallel:
|
||||
model = torch.nn.DataParallel(model)
|
||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||
|
||||
@@ -40,6 +40,9 @@ class TrainingArguments:
|
||||
Using :class:`~transformers.HfArgumentParser` we can turn this class into argparse arguments to be able to specify
|
||||
them on the command line.
|
||||
|
||||
|
||||
|
||||
|
||||
Parameters:
|
||||
output_dir (:obj:`str`):
|
||||
The output directory where the model predictions and checkpoints will be written.
|
||||
@@ -201,6 +204,15 @@ class TrainingArguments:
|
||||
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
|
||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||
model_parallel: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"If there are more than one devices, whether to use model parallelism to distribute the "
|
||||
"model's modules across devices."
|
||||
)
|
||||
},
|
||||
)
|
||||
evaluation_strategy: EvaluationStrategy = field(
|
||||
default="no",
|
||||
metadata={"help": "Run evaluation during training at each logging step."},
|
||||
@@ -366,7 +378,11 @@ class TrainingArguments:
|
||||
"version. Using `--per_device_train_batch_size` is preferred."
|
||||
)
|
||||
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
|
||||
return per_device_batch_size * max(1, self.n_gpu)
|
||||
if not self.model_parallel:
|
||||
train_batch_size = per_device_batch_size * max(1, self.n_gpu)
|
||||
else:
|
||||
train_batch_size = per_device_batch_size
|
||||
return train_batch_size
|
||||
|
||||
@property
|
||||
def eval_batch_size(self) -> int:
|
||||
@@ -379,7 +395,11 @@ class TrainingArguments:
|
||||
"version. Using `--per_device_eval_batch_size` is preferred."
|
||||
)
|
||||
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
|
||||
return per_device_batch_size * max(1, self.n_gpu)
|
||||
if not self.model_parallel:
|
||||
eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
|
||||
else:
|
||||
eval_batch_size = per_device_batch_size
|
||||
return eval_batch_size
|
||||
|
||||
@cached_property
|
||||
@torch_required
|
||||
|
||||
40
src/transformers/utils/model_parallel_utils.py
Normal file
40
src/transformers/utils/model_parallel_utils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# coding=utf-8
|
||||
from math import ceil
|
||||
|
||||
|
||||
def assert_device_map(device_map, num_blocks):
    """Check that ``device_map`` assigns every block of the model to exactly one device.

    Args:
        device_map: Mapping whose values are lists of block indices, one list per device.
        num_blocks: Number of blocks in the model; the valid indices are ``0 .. num_blocks - 1``.

    Raises:
        AssertionError: If a block index is listed more than once, if a block of the model is
            not listed at all, or if an index that is not a block of the model is listed
            (checked in that order).
    """
    valid_blocks = range(num_blocks)
    device_map_blocks = [block for assigned in device_map.values() for block in assigned]

    # Count every index in a single pass instead of calling list.count() per element.
    # Dicts keep insertion order, so duplicates are reported in order of first appearance.
    occurrences = {}
    for block in device_map_blocks:
        occurrences[block] = occurrences.get(block, 0) + 1

    duplicate_blocks = [block for block, count in occurrences.items() if count > 1]
    missing_blocks = [block for block in valid_blocks if block not in occurrences]
    extra_blocks = [block for block in device_map_blocks if block not in valid_blocks]

    assert len(duplicate_blocks) == 0, (
        "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
        "attention blocks were specified more than once: " + str(duplicate_blocks)
    )
    assert len(missing_blocks) == 0, (
        "There are attention blocks for this model that are not specified in the device_map. Add these attention "
        "blocks to a device on the device_map: " + str(missing_blocks)
    )
    assert (
        len(extra_blocks) == 0
    ), "The device_map contains more attention blocks than this model has. Remove these from the device_map:" + str(
        extra_blocks
    )
|
||||
|
||||
|
||||
def get_device_map(n_layers, devices):
    """Return a dictionary of layers distributed evenly across all devices.

    Layers are handed out in consecutive chunks of ``ceil(n_layers / number of devices)``
    indices, in the order the devices are given. When the layers run out early, the
    trailing devices get a shorter chunk or do not appear in the result at all.

    Args:
        n_layers: Number of layers to distribute; the indices are ``0 .. n_layers - 1``.
        devices: Collection of device ids (e.g. ``range(torch.cuda.device_count())``), or an
            int device count, which is treated as ``range(devices)``.

    Returns:
        A dict mapping each device id that received layers to its list of layer indices.

    Raises:
        ValueError: If ``devices`` is empty.
    """
    # Some callers pass the device count itself; treat it as the ids 0 .. count - 1.
    if isinstance(devices, int):
        devices = range(devices)
    devices = list(devices)
    if not devices:
        raise ValueError("get_device_map needs at least one device to distribute the layers over.")

    layers = list(range(n_layers))
    n_blocks = int(ceil(n_layers / len(devices)))
    layers_list = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)]

    return dict(zip(devices, layers_list))
|
||||
@@ -68,6 +68,7 @@ class ModelTesterMixin:
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
test_missing_keys = True
|
||||
test_model_parallel = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
@@ -953,6 +954,97 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@require_torch_multi_gpu
def test_model_parallelization(self):
    # Skip for test classes that have not opted in. The previous `pass` fell through
    # and ran the body anyway.
    if not self.test_model_parallel:
        return

    import subprocess

    def get_current_gpu_memory_use():
        """Return the `memory.used` value nvidia-smi reports for each GPU, one int per device."""
        # An argument list avoids spawning a shell; check_output waits for the process
        # and closes its pipe.
        memory_usage = (
            subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"])
            .decode("utf-8")
            .strip()
        )
        return [int(memory) for memory in memory_usage.split("\n")]

    # Needs a large model to see the difference.
    config = self.model_tester.get_large_model_config()

    for model_class in self.all_parallelizable_model_classes:
        torch.cuda.empty_cache()

        # Retrieve initial memory usage (should be close to 0)
        initial_memory = get_current_gpu_memory_use()

        # Put model on device. The tester's own large config is used as-is; the previous
        # `config.from_pretrained("gpt2")` loaded the "gpt2" config for every model class.
        model = model_class(config)
        model.to("cuda:0")

        # Retrieve the memory after the model is put on the device
        memory_after_model_load = get_current_gpu_memory_use()

        del model
        torch.cuda.empty_cache()

        # The memory use on that device should be higher than it was initially.
        self.assertGreater(memory_after_model_load[0], initial_memory[0])

        # Spread model layers over multiple devices
        model = model_class(config)
        model.parallelize()
        memory_after_parallelization = get_current_gpu_memory_use()

        # Assert that the memory use on all devices is higher than it was when loaded only on CPU
        for n in range(torch.cuda.device_count()):
            self.assertGreater(memory_after_parallelization[n], initial_memory[n])

        # Assert that the memory use of the first device is lower than it was when the entire model was loaded on it
        self.assertLess(memory_after_parallelization[0], memory_after_model_load[0])

        # Assert that the memory use of the second device is higher than it was when the entire model was loaded
        # on the other device.
        self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1])

        del model
        torch.cuda.empty_cache()
|
||||
|
||||
@require_torch_multi_gpu
def test_model_parallel_equal_results(self):
    # Skip for test classes that have not opted in. The previous `pass` fell through
    # and ran the body anyway.
    if not self.test_model_parallel:
        return

    def cast_to_gpu(dictionary):
        """Copy `dictionary`, moving tensor values to cuda:0 and leaving other values untouched."""
        return {k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in dictionary.items()}

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_parallelizable_model_classes:
        inputs_dict = self._prepare_for_class(inputs_dict, model_class)

        # Reference run before the model is spread over the devices.
        model = model_class(config)
        output = model(**inputs_dict)

        model.parallelize()

        # Same model and inputs, now with the inputs on the first GPU.
        parallel_output = model(**cast_to_gpu(inputs_dict))

        # NOTE(review): this assumes iterating the model output yields its values (tuple-style
        # output); if it is a dict-like output, iteration may yield keys only -- confirm.
        for value, parallel_value in zip(output, parallel_output):
            if isinstance(value, torch.Tensor):
                self.assertTrue(torch.allclose(value, parallel_value.to("cpu"), atol=1e-7))
            elif isinstance(value, (tuple, list)):
                for value_, parallel_value_ in zip(value, parallel_value):
                    self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ class GPT2ModelTester:
|
||||
self.eos_token_id = vocab_size - 1
|
||||
self.pad_token_id = vocab_size - 1
|
||||
|
||||
def get_large_model_config(self):
    """Return the configuration loaded from the pretrained "gpt2" checkpoint."""
    large_config = GPT2Config.from_pretrained("gpt2")
    return large_config
|
||||
|
||||
def prepare_config_and_inputs(self, gradient_checkpointing=False):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
@@ -389,7 +392,9 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
all_parallelizable_model_classes = (GPT2LMHeadModel,) if is_torch_available() else ()
|
||||
test_missing_keys = False
|
||||
test_model_parallel = True
|
||||
|
||||
# special case for DoubleHeads model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -85,6 +85,9 @@ class T5ModelTester:
|
||||
self.scope = None
|
||||
self.decoder_layers = decoder_layers
|
||||
|
||||
def get_large_model_config(self):
    """Return the configuration loaded from the pretrained "t5-base" checkpoint."""
    large_config = T5Config.from_pretrained("t5-base")
    return large_config
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
@@ -470,9 +473,18 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_parallelizable_model_classes = (
|
||||
(
|
||||
T5Model,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = False
|
||||
test_model_parallel = True
|
||||
is_encoder_decoder = True
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user