[trainer] --model_parallel hasn't been implemented for most models (#9347)
* --model_parallel hasn't been implemented for most models * make the help clear as well * implement is_parallelizable; use it * oops * remove property
This commit is contained in:
@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
@@ -417,6 +418,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
# trained, but which are deterministic)
|
||||
_keys_to_ignore_on_save = None
|
||||
|
||||
is_parallelizable = False
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
||||
@@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
config_class = GPT2Config
|
||||
load_tf_weights = load_tf_weights_in_gpt2
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
config_class = T5Config
|
||||
load_tf_weights = load_tf_weights_in_t5
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
|
||||
@@ -242,6 +242,11 @@ class Trainer:
|
||||
if model is None and model_init is not None:
|
||||
model = self.call_model_init()
|
||||
|
||||
if self.args.model_parallel and not model.is_parallelizable:
|
||||
raise ValueError(
|
||||
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
|
||||
)
|
||||
|
||||
# Model parallel
|
||||
if model is not None and not self.args.model_parallel:
|
||||
model = model.to(args.device)
|
||||
|
||||
@@ -207,8 +207,8 @@ class TrainingArguments:
|
||||
:obj:`"eval_loss"`.
|
||||
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
|
||||
model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If there is more than one device, whether to use model parallelism to distribute the model's modules across
|
||||
devices or not.
|
||||
If the model supports model parallelism and there is more than one device, whether to use model parallelism
|
||||
to distribute the model's modules across devices or not.
|
||||
ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
||||
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
||||
|
||||
Reference in New Issue
Block a user