From b4fd49b6c54ac34d45cc656f2872b5f392029590 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Fri, 19 Apr 2024 18:05:34 +0200
Subject: [PATCH] Update unwrap from accelerate (#29933)

* Use unwrap with the one in accelerate

* oups

* update unwrap

* fix

* wording

* raise error instead

* comment

* doc

* Update src/transformers/modeling_utils.py

Co-authored-by: Zach Mueller

* style

* put else

---------

Co-authored-by: Zach Mueller
---
 src/transformers/modeling_utils.py | 27 ++++++++++++++++++++++-----
 src/transformers/trainer.py        | 20 ++++++++++----------
 tests/trainer/test_trainer.py      |  7 ++++---
 3 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e4fcd2ebc1..e4fee8a526 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -109,6 +109,7 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
         find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
@@ -4805,18 +4806,34 @@ class SequenceSummary(nn.Module):
         return output
 
 
-def unwrap_model(model: nn.Module) -> nn.Module:
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     """
     Recursively unwraps a model from potential containers (as used in distributed training).
 
     Args:
         model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
     """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
     else:
-        return model
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
+        else:
+            return model
 
 
 def expand_device_map(device_map, param_names, start_prefix):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 92025cb979..f911e1c894 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -63,7 +63,7 @@ from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_h
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -684,7 +684,7 @@ class Trainer:
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
         https://arxiv.org/abs/2310.05914
         """
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)
 
         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -705,7 +705,7 @@ class Trainer:
         if not hasattr(self, "neftune_hook_handle"):
             raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
 
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)
 
         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -1617,7 +1617,7 @@ class Trainer:
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
 
         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
-        if unwrap_model(model) is not model:
+        if self.accelerator.unwrap_model(model) is not model:
             return model
 
         # Mixed precision training with apex (torch < 1.6)
@@ -3165,7 +3165,7 @@ class Trainer:
             self._past = outputs[self.args.past_index]
 
         if labels is not None:
-            unwrapped_model = unwrap_model(model)
+            unwrapped_model = self.accelerator.unwrap_model(model)
             if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
@@ -3272,8 +3272,8 @@ class Trainer:
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
         if not isinstance(model, supported_classes):
-            if isinstance(unwrap_model(model), supported_classes):
-                unwrap_model(model).save_pretrained(
+            if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
                     state_dict=model.state_dict(),
@@ -3311,8 +3311,8 @@ class Trainer:
             if state_dict is None:
                 state_dict = self.model.state_dict()
 
-            if isinstance(unwrap_model(self.model), supported_classes):
-                unwrap_model(self.model).save_pretrained(
+            if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+                self.accelerator.unwrap_model(self.model).save_pretrained(
                     output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                 )
             else:
@@ -3969,7 +3969,7 @@ class Trainer:
             f.write(model_card)
 
         if is_peft_library:
-            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
 
     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 5619a5c98c..8913de4db1 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -123,7 +123,6 @@ if is_torch_available():
         Trainer,
         TrainerState,
     )
-    from transformers.modeling_utils import unwrap_model
     from transformers.trainer_pt_utils import AcceleratorConfig
 
     if is_safetensors_available():
@@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer = get_regression_trainer(learning_rate=0.1)
 
         def assert_flos_extraction(trainer, wrapped_model_to_check):
-            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
-            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+            self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
+            self.assertGreaterEqual(
+                getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
+            )
 
         # with plain model
         assert_flos_extraction(trainer, trainer.model)