[trainer] --model_parallel hasn't been implemented for most models (#9347)
* --model_parallel hasn't been implemented for most models
* make the help clear as well
* implement is_parallelizable; use it
* oops
* remove property
This commit is contained in:
@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||||
derived classes of the same architecture adding modules on top of the base model.
|
derived classes of the same architecture adding modules on top of the base model.
|
||||||
|
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
|
||||||
"""
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
@@ -417,6 +418,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
# trained, but which are deterministic)
|
# trained, but which are deterministic)
|
||||||
_keys_to_ignore_on_save = None
|
_keys_to_ignore_on_save = None
|
||||||
|
|
||||||
|
is_parallelizable = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
|||||||
config_class = GPT2Config
|
config_class = GPT2Config
|
||||||
load_tf_weights = load_tf_weights_in_gpt2
|
load_tf_weights = load_tf_weights_in_gpt2
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
is_parallelizable = True
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
|
|||||||
@@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
|
|||||||
config_class = T5Config
|
config_class = T5Config
|
||||||
load_tf_weights = load_tf_weights_in_t5
|
load_tf_weights = load_tf_weights_in_t5
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
is_parallelizable = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self):
|
def dummy_inputs(self):
|
||||||
|
|||||||
@@ -242,6 +242,11 @@ class Trainer:
|
|||||||
if model is None and model_init is not None:
|
if model is None and model_init is not None:
|
||||||
model = self.call_model_init()
|
model = self.call_model_init()
|
||||||
|
|
||||||
|
if self.args.model_parallel and not model.is_parallelizable:
|
||||||
|
raise ValueError(
|
||||||
|
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
|
||||||
|
)
|
||||||
|
|
||||||
# Model parallel
|
# Model parallel
|
||||||
if model is not None and not self.args.model_parallel:
|
if model is not None and not self.args.model_parallel:
|
||||||
model = model.to(args.device)
|
model = model.to(args.device)
|
||||||
|
|||||||
@@ -207,8 +207,8 @@ class TrainingArguments:
|
|||||||
:obj:`"eval_loss"`.
|
:obj:`"eval_loss"`.
|
||||||
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
|
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
|
||||||
model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
If there is more than one device, whether to use model parallelism to distribute the model's modules across
|
If the model supports model parallelism and there is more than one device, whether to use model parallelism
|
||||||
devices or not.
|
to distribute the model's modules across devices or not.
|
||||||
ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
||||||
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
||||||
|
|||||||
Reference in New Issue
Block a user