[trainer] --model_parallel hasn't been implemented for most models (#9347)

* --model_parallel hasn't been implemented for most models

* make the help clear as well

* implement is_parallelizable; use it

* oops

* remove property
This commit is contained in:
Stas Bekman
2021-01-05 01:01:30 -08:00
committed by GitHub
parent 4225740a7b
commit 748006c0b3
5 changed files with 12 additions and 2 deletions

View File

@@ -242,6 +242,11 @@ class Trainer:
if model is None and model_init is not None:
model = self.call_model_init()
if self.args.model_parallel and not model.is_parallelizable:
raise ValueError(
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
)
# Model parallel
if model is not None and not self.args.model_parallel:
model = model.to(args.device)