From 748006c0b35d64cdee23a3cdc2107a1ce64044b5 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 5 Jan 2021 01:01:30 -0800 Subject: [PATCH] [trainer] --model_parallel hasn't been implemented for most models (#9347) * --model_parallel hasn't been implemented for most models * make the help clear as well * implement is_parallelizable; use it * oops * remove property --- src/transformers/modeling_utils.py | 3 +++ src/transformers/models/gpt2/modeling_gpt2.py | 1 + src/transformers/models/t5/modeling_t5.py | 1 + src/transformers/trainer.py | 5 +++++ src/transformers/training_args.py | 4 ++-- 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fba7aa89cb..d0fc1ad0f4 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. + - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization. """ config_class = None base_model_prefix = "" @@ -417,6 +418,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): # trained, but which are deterministic) _keys_to_ignore_on_save = None + is_parallelizable = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index bb8046c0e2..867a02d361 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel): config_class = GPT2Config load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" + is_parallelizable = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0ce2be3c62..00d9ca30ec 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel): config_class = T5Config load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" + is_parallelizable = True @property def dummy_inputs(self): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f3c21e2d61..f50f2a51db 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -242,6 +242,11 @@ class Trainer: if model is None and model_init is not None: model = self.call_model_init() + if self.args.model_parallel and not model.is_parallelizable: + raise ValueError( + f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" + ) + # Model parallel if model is not None and not self.args.model_parallel: model = model.to(args.device) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 9d78ce41fe..8ac8eb88a0 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -207,8 +207,8 @@ class TrainingArguments: :obj:`"eval_loss"`. - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`. model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`): - If there is more than one device, whether to use model parallelism to distribute the model's modules across - devices or not. + If the model supports model parallelism and there is more than one device, whether to use model parallelism + to distribute the model's modules across devices or not. ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`): When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping