[trainer] refactor place_model_on_device logic, add deepspeed (#10243)
* refactor place_model_on_device logic, add deepspeed * doc * style
This commit is contained in:
@@ -214,6 +214,10 @@ class Trainer:
|
|||||||
inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
|
inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
|
||||||
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
|
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
|
||||||
data parallelism, this means some of the model layers are split on different GPUs).
|
data parallelism, this means some of the model layers are split on different GPUs).
|
||||||
|
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
|
||||||
|
to :obj:`False` if model parallel or deepspeed is used, or if the default
|
||||||
|
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -262,6 +266,11 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
self.is_model_parallel = False
|
self.is_model_parallel = False
|
||||||
|
|
||||||
|
# one place to sort out whether to place the model on device or not
|
||||||
|
self.place_model_on_device = args.place_model_on_device
|
||||||
|
if self.is_model_parallel or (args.deepspeed and args.do_train):
|
||||||
|
self.place_model_on_device = False
|
||||||
|
|
||||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
@@ -272,7 +281,7 @@ class Trainer:
|
|||||||
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||||
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
||||||
# and we only use deepspeed for training at the moment
|
# and we only use deepspeed for training at the moment
|
||||||
if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
|
if self.place_model_on_device:
|
||||||
model = model.to(args.device)
|
model = model.to(args.device)
|
||||||
|
|
||||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||||
@@ -780,7 +789,7 @@ class Trainer:
|
|||||||
|
|
||||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
if model_reloaded:
|
if model_reloaded:
|
||||||
if not self.is_model_parallel and self.args.place_model_on_device:
|
if self.place_model_on_device:
|
||||||
self.model = self.model.to(self.args.device)
|
self.model = self.model.to(self.args.device)
|
||||||
self.model_wrapped = self.model
|
self.model_wrapped = self.model
|
||||||
|
|
||||||
@@ -1033,7 +1042,7 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
if isinstance(self.model, PreTrainedModel):
|
if isinstance(self.model, PreTrainedModel):
|
||||||
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
|
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
|
||||||
if not self.is_model_parallel and self.args.place_model_on_device:
|
if self.place_model_on_device:
|
||||||
self.model = self.model.to(self.args.device)
|
self.model = self.model.to(self.args.device)
|
||||||
else:
|
else:
|
||||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||||
|
|||||||
Reference in New Issue
Block a user