[trainer] self.model_wrapped + _model_unwrap (#9390)
* model wrapped + model_unwrap * cleanup * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * style * deprecation warning * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -162,6 +162,14 @@ if is_fairscale_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _model_unwrap(model: nn.Module) -> nn.Module:
|
||||
# since there could be multiple levels of wrapping, unwrap recursively
|
||||
if hasattr(model, "module"):
|
||||
return _model_unwrap(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
||||
@@ -212,6 +220,16 @@ class Trainer:
|
||||
containing the optimizer and the scheduler to use. Will default to an instance of
|
||||
:class:`~transformers.AdamW` on your model and a scheduler given by
|
||||
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
||||
|
||||
Important accessors:
|
||||
|
||||
``self.model`` - always points to the core model. If using a transformers model, it will be a
|
||||
:class:`PreTrainedModel` subclass.
|
||||
|
||||
``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the
|
||||
original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
|
||||
the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model
|
||||
hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -234,30 +252,37 @@ class Trainer:
|
||||
self.args = args
|
||||
# Seed must be set before instantiating the model when using model
|
||||
set_seed(self.args.seed)
|
||||
assert (
|
||||
model is not None or model_init is not None
|
||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||
self.model_init = model_init
|
||||
self.hp_name = None
|
||||
if model is None and model_init is not None:
|
||||
|
||||
if model is None:
|
||||
if model_init is not None:
|
||||
self.model_init = model_init
|
||||
model = self.call_model_init()
|
||||
|
||||
if self.args.model_parallel and not model.is_parallelizable:
|
||||
raise ValueError(
|
||||
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
|
||||
else:
|
||||
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
|
||||
else:
|
||||
if model_init is not None:
|
||||
warnings.warn(
|
||||
"`Trainer` requires either a `model` or `model_init` argument, but not both. "
|
||||
"`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.model_init = model_init
|
||||
|
||||
# Model parallel
|
||||
if model is not None and not self.args.model_parallel:
|
||||
model = model.to(args.device)
|
||||
|
||||
self.model = model
|
||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Model parallel
|
||||
if not self.args.model_parallel:
|
||||
model = model.to(args.device)
|
||||
|
||||
# later use `self.model is self.model_wrapped` to check if it's wrapped or not
|
||||
self.model_wrapped = model
|
||||
self.model = model
|
||||
|
||||
self.compute_metrics = compute_metrics
|
||||
self.optimizer, self.lr_scheduler = optimizers
|
||||
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
||||
@@ -640,9 +665,11 @@ class Trainer:
|
||||
set_seed(self.args.seed)
|
||||
|
||||
model = self.call_model_init(trial)
|
||||
|
||||
if not self.args.model_parallel:
|
||||
self.model = model.to(self.args.device)
|
||||
model = model.to(self.args.device)
|
||||
|
||||
self.model = model
|
||||
self.model_wrapped = model
|
||||
|
||||
# Reinitializes optimizer and scheduler
|
||||
self.optimizer, self.lr_scheduler = None, None
|
||||
@@ -681,8 +708,9 @@ class Trainer:
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
self._load_optimizer_and_scheduler(model_path)
|
||||
|
||||
model = self.model_wrapped
|
||||
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
model = self.model
|
||||
if self.use_apex:
|
||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||
|
||||
@@ -707,6 +735,14 @@ class Trainer:
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
|
||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||
if model is not self.model:
|
||||
self.model_wrapped = model
|
||||
|
||||
# important: at this point:
|
||||
# self.model is the Transformers Model
|
||||
# self.model_wrapped is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.
|
||||
|
||||
# Train!
|
||||
if is_torch_tpu_available():
|
||||
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
||||
@@ -937,12 +973,10 @@ class Trainer:
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
# In all cases (even distributed/parallel), self.model is always a reference
|
||||
# to the model we want to save.
|
||||
if hasattr(model, "module"):
|
||||
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
|
||||
else:
|
||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||
# want to save.
|
||||
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
|
||||
|
||||
# Save model checkpoint
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
@@ -1630,30 +1664,7 @@ class Trainer:
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
model = self._actual_model(self.model)
|
||||
|
||||
if hasattr(model, "floating_point_ops"):
|
||||
return model.floating_point_ops(inputs)
|
||||
|
||||
if hasattr(self.model, "floating_point_ops"):
|
||||
return self.model.floating_point_ops(inputs)
|
||||
else:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _actual_model(
|
||||
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
|
||||
) -> torch.nn.modules.Module:
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
|
||||
Model object used during training
|
||||
|
||||
Returns:
|
||||
:obj:`torch.nn.modules.Module`: unwrapped module
|
||||
"""
|
||||
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
else:
|
||||
model = model
|
||||
return model
|
||||
|
||||
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
from transformers.trainer import _model_unwrap
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
@@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
||||
Reference in New Issue
Block a user