From 6290169eb3391d72d9a08cab5c54a54b73a87463 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 4 Mar 2021 11:46:11 -0500 Subject: [PATCH] Rework TPU checkpointing in Trainer (#10504) * Rework TPU checkpointing in Trainer * Wraps the barrier in a dist test * Address review comments * Remove line --- src/transformers/configuration_utils.py | 8 +-- src/transformers/modeling_utils.py | 65 ++++++++++++++++--------- src/transformers/trainer.py | 53 +++++++++----------- tests/test_trainer.py | 6 +-- 4 files changed, 74 insertions(+), 58 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 763aecd7e2..0d0b410e0b 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -75,8 +75,6 @@ class PretrainedConfig(object): heads to prune in said layer. For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. - xla_device (:obj:`bool`, `optional`): - A flag to indicate if TPU are available or not. chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`): The chunk size of all feed forward layers in the residual attention blocks. A chunk size of :obj:`0` means that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes @@ -248,7 +246,11 @@ class PretrainedConfig(object): self.task_specific_params = kwargs.pop("task_specific_params", None) # TPU arguments - self.xla_device = kwargs.pop("xla_device", None) + if kwargs.pop("xla_device", None) is not None: + logger.warn( + "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " + "safely remove it from your `config.json` file." + ) # Name or path to the pretrained checkpoint self._name_or_path = str(kwargs.pop("name_or_path", "")) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 16a5f0452d..40601bdf31 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -37,7 +37,6 @@ from .file_utils import ( cached_path, hf_bucket_url, is_remote_url, - is_torch_tpu_available, replace_return_docstrings, ) from .generation_utils import GenerationMixin @@ -781,7 +780,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): self.base_model._prune_heads(heads_to_prune) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method. @@ -789,19 +794,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): Arguments: save_directory (:obj:`str` or :obj:`os.PathLike`): Directory to which to save. Will be created if it doesn't exist. + save_config (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to save the config of the model. Useful when in distributed training like TPUs and need + to call this function on all processes. In this case, set :obj:`save_config=True` only on the main + process to avoid race conditions. + state_dict (nested dictionary of :obj:`torch.Tensor`): + The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to + only save parts of the model or if special precautions need to be taken when recovering the state + dictionary of a model (like when using model parallelism). + save_function (:obj:`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace :obj:`torch.save` by another method. """ if os.path.isfile(save_directory): - logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return os.makedirs(save_directory, exist_ok=True) # Only save the model itself if we are using distributed training - model_to_save = self.module if hasattr(self, "module") else self + model_to_save = unwrap_model(self) # Attach architecture to the config model_to_save.config.architectures = [model_to_save.__class__.__name__] - state_dict = model_to_save.state_dict() + # Save the config + if save_config: + model_to_save.config.save_pretrained(save_directory) + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() # Handle the case where some state_dict keys shouldn't be saved if self._keys_to_ignore_on_save is not None: @@ -809,18 +831,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): # If we save using the predefined names, we can load using `from_pretrained` output_model_file = os.path.join(save_directory, WEIGHTS_NAME) - - if getattr(self.config, "xla_device", False) and is_torch_tpu_available(): - import torch_xla.core.xla_model as xm - - if xm.is_master_ordinal(): - # Save configuration file - model_to_save.config.save_pretrained(save_directory) - # xm.save takes care of saving only from master - xm.save(state_dict, output_model_file) - else: - model_to_save.config.save_pretrained(save_directory) - torch.save(state_dict, output_model_file) + save_function(state_dict, output_model_file) logger.info("Model weights saved in {}".format(output_model_file)) @@ -1181,12 +1192,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): } return model, loading_info - if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available(): - import torch_xla.core.xla_model as xm - - model = xm.send_cpu_data_to_device(model, xm.xla_device()) - model.to(xm.xla_device()) - return model @@ -1634,6 +1639,20 @@ class SequenceSummary(nn.Module): return output +def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (:obj:`torch.nn.Module`): The model to unwrap. + """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model + + def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear: """ Prune a linear layer to keep only entries in index. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0c597f1ab1..d1076bff1c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -60,7 +60,7 @@ from .file_utils import ( is_sagemaker_distributed_available, is_torch_tpu_available, ) -from .modeling_utils import PreTrainedModel +from .modeling_utils import PreTrainedModel, unwrap_model from .optimization import Adafactor, AdamW, get_scheduler from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( @@ -154,14 +154,6 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def _model_unwrap(model: nn.Module) -> nn.Module: - # since there could be multiple levels of wrapping, unwrap recursively - if hasattr(model, "module"): - return _model_unwrap(model.module) - else: - return model - - class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. @@ -359,10 +351,6 @@ class Trainer: # Create output directory if needed if self.is_world_process_zero(): os.makedirs(self.args.output_dir, exist_ok=True) - if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel): - # Set an xla_device flag on the model's config. - # We'll find a more elegant and not need to do this in the future. - self.model.config.xla_device = True if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") @@ -1194,7 +1182,7 @@ class Trainer: def _save_checkpoint(self, model, trial, metrics=None): # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we # want to save except FullyShardedDDP. - # assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model" + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" @@ -1499,13 +1487,15 @@ class Trainer: """ Will save the model, so you can reload it using :obj:`from_pretrained()`. - Will only save from the world_master process (unless in TPUs). + Will only save from the main process. """ - if is_torch_tpu_available(): self._save_tpu(output_dir) - elif self.is_world_process_zero(): - self._save(output_dir) + else: + if self.is_world_process_zero(): + self._save(output_dir) + if self.args.local_rank != -1: + dist.barrier() def _save_tpu(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir @@ -1519,34 +1509,39 @@ class Trainer: # They can then be reloaded using `from_pretrained()` xm.rendezvous("saving_checkpoint") if not isinstance(self.model, PreTrainedModel): - if isinstance(_model_unwrap(self.model), PreTrainedModel): - if xm.is_master_ordinal(): - _model_unwrap(self.model).config.save_pretrained(output_dir) + if isinstance(unwrap_model(self.model), PreTrainedModel): + unwrap_model(self.model).save_pretrained( + output_dir, + save_config=self.is_world_process_zero(), + state_dict=self.model.state_dict(), + save_function=xm.save, + ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - state_dict = self.model.state_dict() - xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + state_dict = self.model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - self.model.save_pretrained(output_dir) + self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save) if self.tokenizer is not None and self.is_world_process_zero(): self.tokenizer.save_pretrained(output_dir) def _save(self, output_dir: Optional[str] = None): + # If we are executing this function, we are the process zero, so we don't check for that. output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info("Saving model checkpoint to %s", output_dir) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): - if isinstance(_model_unwrap(self.model), PreTrainedModel): - _model_unwrap(self.model).config.save_pretrained(output_dir) + if isinstance(unwrap_model(self.model), PreTrainedModel): + unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict()) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - state_dict = self.model.state_dict() - torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + state_dict = self.model.state_dict() + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: self.model.save_pretrained(output_dir) - if self.tokenizer is not None and self.is_world_process_zero(): + if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 65f303244d..105cedd4de 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -57,7 +57,7 @@ if is_torch_available(): Trainer, TrainerState, ) - from transformers.trainer import _model_unwrap + from transformers.modeling_utils import unwrap_model PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" @@ -882,8 +882,8 @@ class TrainerIntegrationTest(unittest.TestCase): trainer = get_regression_trainer(learning_rate=0.1) def assert_flos_extraction(trainer, wrapped_model_to_check): - self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check)) - self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0) + self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check)) + self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0) # with plain model assert_flos_extraction(trainer, trainer.model)