Rework TPU checkpointing in Trainer (#10504)
* Rework TPU checkpointing in Trainer * Wraps the barrier in a dist test * Address review comments * Remove line
This commit is contained in:
@@ -75,8 +75,6 @@ class PretrainedConfig(object):
|
|||||||
heads to prune in said layer.
|
heads to prune in said layer.
|
||||||
|
|
||||||
For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
||||||
xla_device (:obj:`bool`, `optional`):
|
|
||||||
A flag to indicate if TPU are available or not.
|
|
||||||
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
|
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
|
||||||
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of :obj:`0` means
|
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of :obj:`0` means
|
||||||
that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes
|
that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes
|
||||||
@@ -248,7 +246,11 @@ class PretrainedConfig(object):
|
|||||||
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
||||||
|
|
||||||
# TPU arguments
|
# TPU arguments
|
||||||
self.xla_device = kwargs.pop("xla_device", None)
|
if kwargs.pop("xla_device", None) is not None:
|
||||||
|
logger.warn(
|
||||||
|
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
|
||||||
|
"safely remove it from your `config.json` file."
|
||||||
|
)
|
||||||
|
|
||||||
# Name or path to the pretrained checkpoint
|
# Name or path to the pretrained checkpoint
|
||||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from .file_utils import (
|
|||||||
cached_path,
|
cached_path,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
is_torch_tpu_available,
|
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .generation_utils import GenerationMixin
|
from .generation_utils import GenerationMixin
|
||||||
@@ -781,7 +780,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
self.base_model._prune_heads(heads_to_prune)
|
self.base_model._prune_heads(heads_to_prune)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
save_config: bool = True,
|
||||||
|
state_dict: Optional[dict] = None,
|
||||||
|
save_function: Callable = torch.save,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||||
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
|
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
|
||||||
@@ -789,19 +794,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
Arguments:
|
Arguments:
|
||||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||||
Directory to which to save. Will be created if it doesn't exist.
|
Directory to which to save. Will be created if it doesn't exist.
|
||||||
|
save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
|
||||||
|
to call this function on all processes. In this case, set :obj:`save_config=True` only on the main
|
||||||
|
process to avoid race conditions.
|
||||||
|
state_dict (nested dictionary of :obj:`torch.Tensor`):
|
||||||
|
The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to
|
||||||
|
only save parts of the model or if special precautions need to be taken when recovering the state
|
||||||
|
dictionary of a model (like when using model parallelism).
|
||||||
|
save_function (:obj:`Callable`):
|
||||||
|
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||||
|
need to replace :obj:`torch.save` by another method.
|
||||||
"""
|
"""
|
||||||
if os.path.isfile(save_directory):
|
if os.path.isfile(save_directory):
|
||||||
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
return
|
return
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
# Only save the model itself if we are using distributed training
|
# Only save the model itself if we are using distributed training
|
||||||
model_to_save = self.module if hasattr(self, "module") else self
|
model_to_save = unwrap_model(self)
|
||||||
|
|
||||||
# Attach architecture to the config
|
# Attach architecture to the config
|
||||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||||
|
|
||||||
state_dict = model_to_save.state_dict()
|
# Save the config
|
||||||
|
if save_config:
|
||||||
|
model_to_save.config.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = model_to_save.state_dict()
|
||||||
|
|
||||||
# Handle the case where some state_dict keys shouldn't be saved
|
# Handle the case where some state_dict keys shouldn't be saved
|
||||||
if self._keys_to_ignore_on_save is not None:
|
if self._keys_to_ignore_on_save is not None:
|
||||||
@@ -809,18 +831,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||||
|
save_function(state_dict, output_model_file)
|
||||||
if getattr(self.config, "xla_device", False) and is_torch_tpu_available():
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
if xm.is_master_ordinal():
|
|
||||||
# Save configuration file
|
|
||||||
model_to_save.config.save_pretrained(save_directory)
|
|
||||||
# xm.save takes care of saving only from master
|
|
||||||
xm.save(state_dict, output_model_file)
|
|
||||||
else:
|
|
||||||
model_to_save.config.save_pretrained(save_directory)
|
|
||||||
torch.save(state_dict, output_model_file)
|
|
||||||
|
|
||||||
logger.info("Model weights saved in {}".format(output_model_file))
|
logger.info("Model weights saved in {}".format(output_model_file))
|
||||||
|
|
||||||
@@ -1181,12 +1192,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
}
|
}
|
||||||
return model, loading_info
|
return model, loading_info
|
||||||
|
|
||||||
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
|
||||||
model.to(xm.xla_device())
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -1634,6 +1639,20 @@ class SequenceSummary(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
|
||||||
|
"""
|
||||||
|
Recursively unwraps a model from potential containers (as used in distributed training).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`torch.nn.Module`): The model to unwrap.
|
||||||
|
"""
|
||||||
|
# since there could be multiple levels of wrapping, unwrap recursively
|
||||||
|
if hasattr(model, "module"):
|
||||||
|
return unwrap_model(model.module)
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
|
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
|
||||||
"""
|
"""
|
||||||
Prune a linear layer to keep only entries in index.
|
Prune a linear layer to keep only entries in index.
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ from .file_utils import (
|
|||||||
is_sagemaker_distributed_available,
|
is_sagemaker_distributed_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||||
from .optimization import Adafactor, AdamW, get_scheduler
|
from .optimization import Adafactor, AdamW, get_scheduler
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
@@ -154,14 +154,6 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _model_unwrap(model: nn.Module) -> nn.Module:
|
|
||||||
# since there could be multiple levels of wrapping, unwrap recursively
|
|
||||||
if hasattr(model, "module"):
|
|
||||||
return _model_unwrap(model.module)
|
|
||||||
else:
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
||||||
@@ -359,10 +351,6 @@ class Trainer:
|
|||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||||
if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
|
|
||||||
# Set an xla_device flag on the model's config.
|
|
||||||
# We'll find a more elegant and not need to do this in the future.
|
|
||||||
self.model.config.xla_device = True
|
|
||||||
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
||||||
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
|
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
|
||||||
|
|
||||||
@@ -1194,7 +1182,7 @@ class Trainer:
|
|||||||
def _save_checkpoint(self, model, trial, metrics=None):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||||
# want to save except FullyShardedDDP.
|
# want to save except FullyShardedDDP.
|
||||||
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
|
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
||||||
|
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
@@ -1499,13 +1487,15 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
Will save the model, so you can reload it using :obj:`from_pretrained()`.
|
Will save the model, so you can reload it using :obj:`from_pretrained()`.
|
||||||
|
|
||||||
Will only save from the world_master process (unless in TPUs).
|
Will only save from the main process.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
self._save_tpu(output_dir)
|
self._save_tpu(output_dir)
|
||||||
elif self.is_world_process_zero():
|
else:
|
||||||
self._save(output_dir)
|
if self.is_world_process_zero():
|
||||||
|
self._save(output_dir)
|
||||||
|
if self.args.local_rank != -1:
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
def _save_tpu(self, output_dir: Optional[str] = None):
|
def _save_tpu(self, output_dir: Optional[str] = None):
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
@@ -1519,34 +1509,39 @@ class Trainer:
|
|||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
xm.rendezvous("saving_checkpoint")
|
xm.rendezvous("saving_checkpoint")
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
if isinstance(_model_unwrap(self.model), PreTrainedModel):
|
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||||
if xm.is_master_ordinal():
|
unwrap_model(self.model).save_pretrained(
|
||||||
_model_unwrap(self.model).config.save_pretrained(output_dir)
|
output_dir,
|
||||||
|
save_config=self.is_world_process_zero(),
|
||||||
|
state_dict=self.model.state_dict(),
|
||||||
|
save_function=xm.save,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
state_dict = self.model.state_dict()
|
state_dict = self.model.state_dict()
|
||||||
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
else:
|
else:
|
||||||
self.model.save_pretrained(output_dir)
|
self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
|
||||||
if self.tokenizer is not None and self.is_world_process_zero():
|
if self.tokenizer is not None and self.is_world_process_zero():
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
def _save(self, output_dir: Optional[str] = None):
|
def _save(self, output_dir: Optional[str] = None):
|
||||||
|
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
if isinstance(_model_unwrap(self.model), PreTrainedModel):
|
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||||
_model_unwrap(self.model).config.save_pretrained(output_dir)
|
unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
|
||||||
else:
|
else:
|
||||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
state_dict = self.model.state_dict()
|
state_dict = self.model.state_dict()
|
||||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
else:
|
else:
|
||||||
self.model.save_pretrained(output_dir)
|
self.model.save_pretrained(output_dir)
|
||||||
if self.tokenizer is not None and self.is_world_process_zero():
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ if is_torch_available():
|
|||||||
Trainer,
|
Trainer,
|
||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
from transformers.trainer import _model_unwrap
|
from transformers.modeling_utils import unwrap_model
|
||||||
|
|
||||||
|
|
||||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||||
@@ -882,8 +882,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
trainer = get_regression_trainer(learning_rate=0.1)
|
trainer = get_regression_trainer(learning_rate=0.1)
|
||||||
|
|
||||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||||
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
|
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
|
||||||
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
|
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||||
|
|
||||||
# with plain model
|
# with plain model
|
||||||
assert_flos_extraction(trainer, trainer.model)
|
assert_flos_extraction(trainer, trainer.model)
|
||||||
|
|||||||
Reference in New Issue
Block a user