[trainer] self.model_wrapped + _model_unwrap (#9390)
* model wrapped + model_unwrap
* cleanup
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style
* deprecation warning
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -162,6 +162,14 @@ if is_fairscale_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_unwrap(model: nn.Module) -> nn.Module:
    """
    Return the innermost model hidden behind any number of wrapping layers.

    An object is treated as a wrapper when it exposes a ``module`` attribute; the value of that attribute is then
    inspected in turn, until an object without a ``module`` attribute is reached.

    Args:
        model (:obj:`nn.Module`):
            The model to unwrap. It may or may not be wrapped.

    Returns:
        :obj:`nn.Module`: The first object in the ``.module`` chain that has no ``module`` attribute. If ``model``
        itself has no such attribute, it is returned unchanged.
    """
    # since there could be multiple levels of wrapping, keep peeling layers until none is left
    while hasattr(model, "module"):
        model = model.module
    return model
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
||||||
@@ -212,6 +220,16 @@ class Trainer:
|
|||||||
containing the optimizer and the scheduler to use. Will default to an instance of
|
containing the optimizer and the scheduler to use. Will default to an instance of
|
||||||
:class:`~transformers.AdamW` on your model and a scheduler given by
|
:class:`~transformers.AdamW` on your model and a scheduler given by
|
||||||
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
||||||
|
|
||||||
|
Important accessors:
|
||||||
|
|
||||||
|
``self.model`` - always points to the core model. If using a transformers model, it will be a
|
||||||
|
:class:`PreTrainedModel` subclass.
|
||||||
|
|
||||||
|
``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the
|
||||||
|
original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
|
||||||
|
the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model
|
||||||
|
hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -234,30 +252,37 @@ class Trainer:
|
|||||||
self.args = args
|
self.args = args
|
||||||
# Seed must be set before instantiating the model when using model
|
# Seed must be set before instantiating the model when using model
|
||||||
set_seed(self.args.seed)
|
set_seed(self.args.seed)
|
||||||
assert (
|
|
||||||
model is not None or model_init is not None
|
|
||||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
|
||||||
self.model_init = model_init
|
|
||||||
self.hp_name = None
|
self.hp_name = None
|
||||||
if model is None and model_init is not None:
|
|
||||||
|
if model is None:
|
||||||
|
if model_init is not None:
|
||||||
|
self.model_init = model_init
|
||||||
model = self.call_model_init()
|
model = self.call_model_init()
|
||||||
|
else:
|
||||||
if self.args.model_parallel and not model.is_parallelizable:
|
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
|
||||||
raise ValueError(
|
else:
|
||||||
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
|
if model_init is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`Trainer` requires either a `model` or `model_init` argument, but not both. "
|
||||||
|
"`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
|
||||||
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
self.model_init = model_init
|
||||||
|
|
||||||
# Model parallel
|
|
||||||
if model is not None and not self.args.model_parallel:
|
|
||||||
model = model.to(args.device)
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
|
if not self.args.model_parallel:
|
||||||
|
model = model.to(args.device)
|
||||||
|
|
||||||
|
# later use `self.model is self.model_wrapped` to check if it's wrapped or not
|
||||||
|
self.model_wrapped = model
|
||||||
|
self.model = model
|
||||||
|
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.optimizer, self.lr_scheduler = optimizers
|
self.optimizer, self.lr_scheduler = optimizers
|
||||||
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
||||||
@@ -640,9 +665,11 @@ class Trainer:
|
|||||||
set_seed(self.args.seed)
|
set_seed(self.args.seed)
|
||||||
|
|
||||||
model = self.call_model_init(trial)
|
model = self.call_model_init(trial)
|
||||||
|
|
||||||
if not self.args.model_parallel:
|
if not self.args.model_parallel:
|
||||||
self.model = model.to(self.args.device)
|
model = model.to(self.args.device)
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.model_wrapped = model
|
||||||
|
|
||||||
# Reinitializes optimizer and scheduler
|
# Reinitializes optimizer and scheduler
|
||||||
self.optimizer, self.lr_scheduler = None, None
|
self.optimizer, self.lr_scheduler = None, None
|
||||||
@@ -681,8 +708,9 @@ class Trainer:
|
|||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
self._load_optimizer_and_scheduler(model_path)
|
self._load_optimizer_and_scheduler(model_path)
|
||||||
|
|
||||||
|
model = self.model_wrapped
|
||||||
|
|
||||||
# Mixed precision training with apex (torch < 1.6)
|
# Mixed precision training with apex (torch < 1.6)
|
||||||
model = self.model
|
|
||||||
if self.use_apex:
|
if self.use_apex:
|
||||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||||
|
|
||||||
@@ -707,6 +735,14 @@ class Trainer:
|
|||||||
# find_unused_parameters breaks checkpointing as per
|
# find_unused_parameters breaks checkpointing as per
|
||||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||||
|
|
||||||
|
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||||
|
if model is not self.model:
|
||||||
|
self.model_wrapped = model
|
||||||
|
|
||||||
|
# important: at this point:
|
||||||
|
# self.model is the Transformers Model
|
||||||
|
# self.model_wrapped is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
||||||
@@ -937,12 +973,10 @@ class Trainer:
|
|||||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||||
|
|
||||||
def _save_checkpoint(self, model, trial, metrics=None):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# In all cases (even distributed/parallel), self.model is always a reference
|
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||||
# to the model we want to save.
|
# want to save.
|
||||||
if hasattr(model, "module"):
|
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
|
||||||
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
|
|
||||||
else:
|
|
||||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
|
||||||
@@ -1630,30 +1664,7 @@ class Trainer:
|
|||||||
Returns:
|
Returns:
|
||||||
:obj:`int`: The number of floating-point operations.
|
:obj:`int`: The number of floating-point operations.
|
||||||
"""
|
"""
|
||||||
|
if hasattr(self.model, "floating_point_ops"):
|
||||||
model = self._actual_model(self.model)
|
return self.model.floating_point_ops(inputs)
|
||||||
|
|
||||||
if hasattr(model, "floating_point_ops"):
|
|
||||||
return model.floating_point_ops(inputs)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _actual_model(
|
|
||||||
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
|
|
||||||
) -> torch.nn.modules.Module:
|
|
||||||
"""
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
|
|
||||||
Model object used during training
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`torch.nn.modules.Module`: unwrapped module
|
|
||||||
"""
|
|
||||||
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
|
||||||
model = model.module
|
|
||||||
else:
|
|
||||||
model = model
|
|
||||||
return model
|
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ if is_torch_available():
|
|||||||
Trainer,
|
Trainer,
|
||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer import _model_unwrap
|
||||||
|
|
||||||
|
|
||||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||||
@@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
trainer = get_regression_trainer(learning_rate=0.1)
|
trainer = get_regression_trainer(learning_rate=0.1)
|
||||||
|
|
||||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||||
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
|
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
|
||||||
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||||
|
|
||||||
# with plain model
|
# with plain model
|
||||||
assert_flos_extraction(trainer, trainer.model)
|
assert_flos_extraction(trainer, trainer.model)
|
||||||
|
|||||||
Reference in New Issue
Block a user