From 9f675b05d4770655afa90ef51333aec032021a6b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 6 Jan 2021 03:50:11 -0800 Subject: [PATCH] [trainer] self.model_wrapped + _model_unwrap (#9390) * model wrapped + model_unwrap * cleanup * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * style * deprecation warning * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/trainer.py | 109 ++++++++++++++++++++---------------- tests/test_trainer.py | 5 +- 2 files changed, 63 insertions(+), 51 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f50f2a51db..0c0f8ed9fc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -162,6 +162,14 @@ if is_fairscale_available(): logger = logging.get_logger(__name__) +def _model_unwrap(model: nn.Module) -> nn.Module: + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return _model_unwrap(model.module) + else: + return model + + class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. @@ -212,6 +220,16 @@ class Trainer: containing the optimizer and the scheduler to use. Will default to an instance of :class:`~transformers.AdamW` on your model and a scheduler given by :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`. + + Important accessors: + + ``self.model`` - always points to the core model. If using a transformers model, it will be a + :class:`PreTrainedModel` subclass. + + ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``, + the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model + hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. """ def __init__( @@ -234,30 +252,37 @@ class Trainer: self.args = args # Seed must be set before instantiating the model when using model set_seed(self.args.seed) - assert ( - model is not None or model_init is not None - ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument." - self.model_init = model_init self.hp_name = None - if model is None and model_init is not None: - model = self.call_model_init() - if self.args.model_parallel and not model.is_parallelizable: - raise ValueError( - f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" - ) + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. " + "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.", + FutureWarning, + ) + self.model_init = model_init - # Model parallel - if model is not None and not self.args.model_parallel: - model = model.to(args.device) - - self.model = model default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.tokenizer = tokenizer + # Model parallel + if not self.args.model_parallel: + model = model.to(args.device) + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + self.compute_metrics = compute_metrics self.optimizer, self.lr_scheduler = optimizers if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): @@ -640,9 +665,11 @@ class Trainer: set_seed(self.args.seed) model = self.call_model_init(trial) - if not self.args.model_parallel: - self.model = model.to(self.args.device) + model = model.to(self.args.device) + + self.model = model + self.model_wrapped = model # Reinitializes optimizer and scheduler self.optimizer, self.lr_scheduler = None, None @@ -681,8 +708,9 @@ class Trainer: # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(model_path) + model = self.model_wrapped + # Mixed precision training with apex (torch < 1.6) - model = self.model if self.use_apex: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) @@ -707,6 +735,14 @@ class Trainer: # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc. + # Train! if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() @@ -937,12 +973,10 @@ class Trainer: self.control = self.callback_handler.on_save(self.args, self.state, self.control) def _save_checkpoint(self, model, trial, metrics=None): - # In all cases (even distributed/parallel), self.model is always a reference - # to the model we want to save. - if hasattr(model, "module"): - assert model.module is self.model, f"Module {model.module} should be a reference to self.model" - else: - assert model is self.model, f"Model {model} should be a reference to self.model" + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save. + assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model" + # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" @@ -1630,30 +1664,7 @@ class Trainer: Returns: :obj:`int`: The number of floating-point operations. """ - - model = self._actual_model(self.model) - - if hasattr(model, "floating_point_ops"): - return model.floating_point_ops(inputs) - + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) else: return 0 - - @staticmethod - def _actual_model( - model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module] - ) -> torch.nn.modules.Module: - """ - - Args: - model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`): - Model object used during training - - Returns: - :obj:`torch.nn.modules.Module`: unwrapped module - """ - if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel): - model = model.module - else: - model = model - return model diff --git a/tests/test_trainer.py b/tests/test_trainer.py index b95a21c653..0443a1429e 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -53,6 +53,7 @@ if is_torch_available(): Trainer, TrainerState, ) + from transformers.trainer import _model_unwrap PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" @@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase): trainer = get_regression_trainer(learning_rate=0.1) def assert_flos_extraction(trainer, wrapped_model_to_check): - self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check)) - self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0) + self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check)) + self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0) # with plain model assert_flos_extraction(trainer, trainer.model)