From 31245775e5772fbded1ac07ed89fbba3b5af0cb9 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 11 Feb 2021 18:44:18 -0500 Subject: [PATCH] Add SageMakerTrainer for model parallelism (#10122) * Refactor things out of main train * Store signature * Add SageMakerTrainer * Init + Copyright * Address review comments --- src/transformers/sagemaker/__init__.py | 20 ++ src/transformers/sagemaker/trainer_sm.py | 178 ++++++++++++++++++ .../sagemaker/training_args_sm.py | 89 +++++++++ src/transformers/trainer.py | 105 ++++++----- src/transformers/training_args.py | 7 + 5 files changed, 349 insertions(+), 50 deletions(-) create mode 100644 src/transformers/sagemaker/__init__.py create mode 100644 src/transformers/sagemaker/trainer_sm.py create mode 100644 src/transformers/sagemaker/training_args_sm.py diff --git a/src/transformers/sagemaker/__init__.py b/src/transformers/sagemaker/__init__.py new file mode 100644 index 0000000000..46222fdf7c --- /dev/null +++ b/src/transformers/sagemaker/__init__.py @@ -0,0 +1,20 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+
+from .trainer_sm import SageMakerTrainer
+from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
diff --git a/src/transformers/sagemaker/trainer_sm.py b/src/transformers/sagemaker/trainer_sm.py
new file mode 100644
index 0000000000..63b16ab227
--- /dev/null
+++ b/src/transformers/sagemaker/trainer_sm.py
@@ -0,0 +1,178 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.distributed import DistributedSampler
+
+from ..trainer import Trainer
+from ..trainer_pt_utils import (
+    DistributedLengthGroupedSampler,
+    SequentialDistributedSampler,
+    nested_detach,
+    nested_numpify,
+)
+from ..utils import logging
+from .training_args_sm import is_smdistributed_available
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_smdistributed_available():
+    import smdistributed.modelparallel.torch as smp
+
+    @smp.step()
+    def forward_backward(model, inputs):
+        outputs = model(**inputs)
+        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        model.backward(loss)
+        return loss
+
+    @smp.step()
+    def forward_only(model, inputs):
+        return model(**inputs)
+
+    def smp_gather(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_gather(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
+        elif not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+            )
+        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        return torch.cat([t.cpu() for t in all_tensors], dim=0)
+
+    def nested_smp_concat(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(nested_smp_concat(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: nested_smp_concat(v) for k, v in tensor.items()})
+        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
+        # which is also the name of the decorator so Python is confused.
+        return tensor.concat().detach().cpu()
+
+
+class SageMakerTrainer(Trainer):
+    def __init__(self, args=None, **kwargs):
+        super().__init__(args=args, **kwargs)
+        self.is_model_parallel_enabled = is_smdistributed_available() and getattr(self.args, "mp_parameters", "") != ""
+        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
+            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
+
+    def _get_train_sampler(self):
+        if self.is_model_parallel_enabled:
+            if self.args.group_by_length:
+                return DistributedLengthGroupedSampler(
+                    self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
+                )
+            else:
+                return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
+        else:
+            return super()._get_train_sampler()
+
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
+        if self.is_model_parallel_enabled:
+            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
+        else:
+            return super()._get_eval_sampler(eval_dataset)
+
+    def _wrap_model(self, model, training=True):
+        if self.is_model_parallel_enabled:
+            # Wrapping the base model twice in a DistributedModel will raise an error.
+            if isinstance(self.model_wrapped, smp.model.DistributedModel):
+                return self.model_wrapped
+            return smp.DistributedModel(model)
+        else:
+            return super()._wrap_model(model, training=training)
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        super().create_optimizer_and_scheduler(num_training_steps)
+        if self.is_model_parallel_enabled:
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        if self.is_model_parallel_enabled:
+            model.train()
+            inputs = self._prepare_inputs(inputs)
+            loss_mb = forward_backward(model, inputs)
+            return loss_mb.reduce_mean().detach().to(self.args.device)
+        else:
+            return super().training_step(model, inputs)
+
+    def _gather_and_numpify(self, tensors, name):
+        if self.is_model_parallel_enabled:
+            tensors = smp_gather(tensors)
+            return nested_numpify(tensors)
+        else:
+            return super()._gather_and_numpify(tensors, name)
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if self.is_model_parallel_enabled:
+            has_labels = all(inputs.get(k) is not None for k in self.label_names)
+            inputs = self._prepare_inputs(inputs)
+
+            if ignore_keys is None:
+                if hasattr(self.model, "config"):
+                    ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+                else:
+                    ignore_keys = []
+
+            with torch.no_grad():
+                raw_outputs = forward_only(model, inputs)
+                if has_labels:
+                    if isinstance(raw_outputs, dict):
+                        loss_mb = raw_outputs["loss"]
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        loss_mb = raw_outputs[0]
+                        logits_mb = raw_outputs[1:]
+
+                    loss = loss_mb.reduce_mean().detach().cpu()
+                    logits = nested_smp_concat(logits_mb)
+                else:
+                    loss = None
+                    if isinstance(raw_outputs, dict):
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                    else:
+                        logits_mb = raw_outputs
+                    logits = nested_smp_concat(logits_mb)
+
+            if prediction_loss_only:
+                return (loss, None, None)
+
+            if len(logits) == 1:
+                logits = logits[0]
+
+            if has_labels:
+                labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+                if len(labels) == 1:
+                    labels = labels[0]
+            else:
+                labels = None
+
+            return (loss, logits, labels)
+        else:
+            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
diff --git a/src/transformers/sagemaker/training_args_sm.py b/src/transformers/sagemaker/training_args_sm.py
new file mode 100644
index 0000000000..0aaef833ca
--- /dev/null
+++ b/src/transformers/sagemaker/training_args_sm.py
@@ -0,0 +1,89 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+from dataclasses import dataclass, field
+
+import torch
+
+from transformers.file_utils import cached_property, is_sagemaker_distributed_available
+from transformers.training_args import TrainingArguments
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def is_smdistributed_available():
+    return importlib.util.find_spec("smdistributed") is not None
+
+
+if is_smdistributed_available():
+    import smdistributed.modelparallel.torch as smp
+
+
+@dataclass
+class SageMakerTrainingArguments(TrainingArguments):
+    mp_parameters: str = field(
+        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        if is_smdistributed_available() and self.mp_parameters != "":
+            smp.init()
+
+    @cached_property
+    def _setup_devices(self) -> "torch.device":
+        logger.info("PyTorch: setting up devices")
+        if self.no_cuda:
+            device = torch.device("cpu")
+            self._n_gpu = 0
+        elif is_smdistributed_available() and self.mp_parameters != "":
+            local_rank = smp.local_rank()
+            device = torch.device("cuda", local_rank)
+            self._n_gpu = 1
+        elif is_sagemaker_distributed_available():
+            import smdistributed.dataparallel.torch.distributed as dist
+
+            dist.init_process_group()
+            self.local_rank = dist.get_local_rank()
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
+        elif self.local_rank == -1:
+            # if n_gpu is > 1 we'll use nn.DataParallel.
+            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+            # trigger an error that a device index is missing. Index 0 takes into account the
+            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+            # will use the first GPU in that env, i.e. GPU#1
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            # Record how many CUDA devices are visible; per the note above, a count greater than one makes the
+            # Trainer fall back to nn.DataParallel.
+            self._n_gpu = torch.cuda.device_count()
+        else:
+            # Here, we'll use torch.distributed.
+            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+            torch.distributed.init_process_group(backend="nccl")
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
+
+        if device.type == "cuda":
+            torch.cuda.set_device(device)
+
+        return device
+
+    @property
+    def place_model_on_device(self):
+        return not (is_smdistributed_available() and self.mp_parameters != "")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e7619b362a..858622d09e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -272,7 +272,7 @@ class Trainer:
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         # and we only use deepspeed for training at the moment
-        if not self.is_model_parallel and not (args.deepspeed and args.do_train):
+        if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
             model = model.to(args.device)
 
         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
@@ -319,6 +319,7 @@ class Trainer:
         if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
             raise ValueError("eval_dataset must implement __len__")
 
+        self._signature_columns = None
         if is_datasets_available():
             if isinstance(train_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.train_dataset, description="training")
@@ -425,16 +426,18 @@ class Trainer:
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
             return
-        # Inspect model forward signature to keep only the arguments it accepts.
-        signature = inspect.signature(self.model.forward)
-        signature_columns = list(signature.parameters.keys())
-        # Labels may be named label or label_ids, the default data collator handles that.
-        signature_columns += ["label", "label_ids"]
-        columns = [k for k in signature_columns if k in dataset.column_names]
-        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            signature = inspect.signature(self.model.forward)
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids, the default data collator handles that.
+            self._signature_columns += ["label", "label_ids"]
+        columns = [k for k in self._signature_columns if k in dataset.column_names]
+        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
         dset_description = "" if description is None else f"in the {description} set "
         logger.info(
-            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+            f"The following columns {dset_description}don't have a corresponding argument in "
+            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
         )
         dataset.set_format(type=dataset.format["type"], columns=columns)
 
@@ -684,6 +687,45 @@ class Trainer:
 
         return model
 
+    def _wrap_model(self, model, training=True):
+        # Mixed precision training with apex (torch < 1.6)
+        if self.use_apex and training:
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization)
+        if self.args.n_gpu > 1:
+            model = torch.nn.DataParallel(model)
+
+        # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        if not training:
+            return model
+
+        # Distributed training (should be after apex fp16 initialization)
+        if self.sharded_dpp:
+            model = ShardedDDP(model, self.optimizer)
+        elif is_sagemaker_distributed_available():
+            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
+        elif self.deepspeed:
+            pass  # already initialized its own DDP earlier
+        elif self.args.local_rank != -1:
+            if self.args.ddp_find_unused_parameters is not None:
+                find_unused_parameters = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
+            else:
+                find_unused_parameters = True
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.args.local_rank],
+                output_device=self.args.local_rank,
+                find_unused_parameters=find_unused_parameters,
+            )
+
+        return model
+
     def train(
         self,
         resume_from_checkpoint: Optional[str] = None,
@@ -736,7 +778,7 @@ class Trainer:
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
-            if not self.is_model_parallel:
+            if not self.is_model_parallel and self.args.place_model_on_device:
                 self.model = self.model.to(self.args.device)
             self.model_wrapped = self.model
 
@@ -783,38 +825,7 @@ class Trainer:
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(resume_from_checkpoint)
 
-        model = self.model_wrapped
-
-        # Mixed precision training with apex (torch < 1.6)
-        if self.use_apex:
-            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
-
-        # Multi-gpu training (should be after apex fp16 initialization)
-        if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
-
-        # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_dpp:
-            model = ShardedDDP(model, self.optimizer)
-        elif is_sagemaker_distributed_available():
-            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
-        elif self.deepspeed:
-            pass  # already initialized its own DDP earlier
-        elif self.args.local_rank != -1:
-            if self.args.ddp_find_unused_parameters is not None:
-                find_unused_parameters = self.args.ddp_find_unused_parameters
-            elif isinstance(model, PreTrainedModel):
-                # find_unused_parameters breaks checkpointing as per
-                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
-            else:
-                find_unused_parameters = True
-            model = torch.nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank],
-                output_device=self.args.local_rank,
-                find_unused_parameters=find_unused_parameters,
-            )
+        model = self._wrap_model(self.model_wrapped)
 
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
@@ -1020,7 +1031,7 @@ class Trainer:
             )
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if not self.is_model_parallel:
+                if not self.is_model_parallel and self.args.place_model_on_device:
                     self.model = self.model.to(self.args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
@@ -1610,13 +1621,7 @@ class Trainer:
             # flagging only for when --do_train wasn't passed as only then it's redundant
             logger.info("Detected the deepspeed argument but it will not be used for evaluation")
 
-        model = self.model
-
-        # multi-gpu eval
-        if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
-        # Note: in torch.distributed mode, there's no point in wrapping the model
-        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        model = self._wrap_model(self.model, training=False)
 
         batch_size = dataloader.batch_size
         num_examples = self.num_examples(dataloader)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 1054929b4e..396cef24f9 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -637,6 +637,13 @@ class TrainingArguments:
         else:
             return ParallelMode.NOT_PARALLEL
 
+    @property
+    def place_model_on_device(self):
+        """
+        Can be subclassed and overridden for some specific integrations.
+        """
+        return True
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).