Rework TPU checkpointing in Trainer (#10504)
* Rework TPU checkpointing in Trainer * Wraps the barrier in a dist test * Address review comments * Remove line
This commit is contained in:
@@ -37,7 +37,6 @@ from .file_utils import (
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_remote_url,
|
||||
is_torch_tpu_available,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .generation_utils import GenerationMixin
|
||||
@@ -781,7 +780,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
self.base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_config: bool = True,
|
||||
state_dict: Optional[dict] = None,
|
||||
save_function: Callable = torch.save,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
|
||||
@@ -789,19 +794,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
Arguments:
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
|
||||
to call this function on all processes. In this case, set :obj:`save_config=True` only on the main
|
||||
process to avoid race conditions.
|
||||
state_dict (nested dictionary of :obj:`torch.Tensor`):
|
||||
The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to
|
||||
only save parts of the model or if special precautions need to be taken when recovering the state
|
||||
dictionary of a model (like when using model parallelism).
|
||||
save_function (:obj:`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace :obj:`torch.save` by another method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Only save the model itself if we are using distributed training
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
model_to_save = unwrap_model(self)
|
||||
|
||||
# Attach architecture to the config
|
||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||
|
||||
state_dict = model_to_save.state_dict()
|
||||
# Save the config
|
||||
if save_config:
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
|
||||
# Save the model
|
||||
if state_dict is None:
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Handle the case where some state_dict keys shouldn't be saved
|
||||
if self._keys_to_ignore_on_save is not None:
|
||||
@@ -809,18 +831,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
|
||||
if getattr(self.config, "xla_device", False) and is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if xm.is_master_ordinal():
|
||||
# Save configuration file
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
# xm.save takes care of saving only from master
|
||||
xm.save(state_dict, output_model_file)
|
||||
else:
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
torch.save(state_dict, output_model_file)
|
||||
save_function(state_dict, output_model_file)
|
||||
|
||||
logger.info("Model weights saved in {}".format(output_model_file))
|
||||
|
||||
@@ -1181,12 +1192,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
}
|
||||
return model, loading_info
|
||||
|
||||
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
||||
model.to(xm.xla_device())
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -1634,6 +1639,20 @@ class SequenceSummary(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
|
||||
"""
|
||||
Recursively unwraps a model from potential containers (as used in distributed training).
|
||||
|
||||
Args:
|
||||
model (:obj:`torch.nn.Module`): The model to unwrap.
|
||||
"""
|
||||
# since there could be multiple levels of wrapping, unwrap recursively
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
|
||||
"""
|
||||
Prune a linear layer to keep only entries in index.
|
||||
|
||||
Reference in New Issue
Block a user