[trainer] --model_parallel hasn't been implemented for most models (#9347)

* --model_parallel hasn't been implemented for most models

* make the help text clear as well

* implement is_parallelizable; use it

* oops

* remove property
This commit is contained in:
Stas Bekman
2021-01-05 01:01:30 -08:00
committed by GitHub
parent 4225740a7b
commit 748006c0b3
5 changed files with 12 additions and 2 deletions

View File

@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelism.
"""
config_class = None
base_model_prefix = ""
@@ -417,6 +418,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# trained, but which are deterministic)
_keys_to_ignore_on_save = None
is_parallelizable = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""

View File

@@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
config_class = GPT2Config
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
is_parallelizable = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

View File

@@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
config_class = T5Config
load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer"
is_parallelizable = True
@property
def dummy_inputs(self):

View File

@@ -242,6 +242,11 @@ class Trainer:
if model is None and model_init is not None:
model = self.call_model_init()
if self.args.model_parallel and not model.is_parallelizable:
raise ValueError(
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
)
# Model parallel
if model is not None and not self.args.model_parallel:
model = model.to(args.device)

View File

@@ -207,8 +207,8 @@ class TrainingArguments:
:obj:`"eval_loss"`.
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
If there is more than one device, whether to use model parallelism to distribute the model's modules across
devices or not.
If the model supports model parallelism and there is more than one device, whether to use model parallelism
to distribute the model's modules across devices or not.
ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping