[trainer] implement support for full fp16 in evaluation/predict (#10268)
* implement --fp16_full_eval * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * style * add test Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -218,6 +218,8 @@ class Trainer:
|
|||||||
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
|
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
|
||||||
to :obj:`False` if model parallel or deepspeed is used, or if the default
|
to :obj:`False` if model parallel or deepspeed is used, or if the default
|
||||||
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
|
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
|
||||||
|
- **is_in_train** -- Whether or not a model is currently running ``train`` (e.g. when ``evaluate`` is called
|
||||||
|
while in ``train``)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -243,6 +245,7 @@ class Trainer:
|
|||||||
set_seed(self.args.seed)
|
set_seed(self.args.seed)
|
||||||
self.hp_name = None
|
self.hp_name = None
|
||||||
self.deepspeed = None
|
self.deepspeed = None
|
||||||
|
self.is_in_train = False
|
||||||
|
|
||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||||
@@ -273,7 +276,7 @@ class Trainer:
|
|||||||
|
|
||||||
# one place to sort out whether to place the model on device or not
|
# one place to sort out whether to place the model on device or not
|
||||||
self.place_model_on_device = args.place_model_on_device
|
self.place_model_on_device = args.place_model_on_device
|
||||||
if self.is_model_parallel or (args.deepspeed and args.do_train):
|
if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
|
||||||
self.place_model_on_device = False
|
self.place_model_on_device = False
|
||||||
|
|
||||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||||
@@ -713,6 +716,10 @@ class Trainer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True):
|
def _wrap_model(self, model, training=True):
|
||||||
|
# already initialized its own DDP and AMP
|
||||||
|
if self.deepspeed:
|
||||||
|
return model
|
||||||
|
|
||||||
# Mixed precision training with apex (torch < 1.6)
|
# Mixed precision training with apex (torch < 1.6)
|
||||||
if self.use_apex and training:
|
if self.use_apex and training:
|
||||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||||
@@ -731,8 +738,6 @@ class Trainer:
|
|||||||
model = ShardedDDP(model, self.optimizer)
|
model = ShardedDDP(model, self.optimizer)
|
||||||
elif is_sagemaker_distributed_available():
|
elif is_sagemaker_distributed_available():
|
||||||
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
||||||
elif self.deepspeed:
|
|
||||||
pass # already initialized its own DDP earlier
|
|
||||||
elif self.args.local_rank != -1:
|
elif self.args.local_rank != -1:
|
||||||
if self.args.ddp_find_unused_parameters is not None:
|
if self.args.ddp_find_unused_parameters is not None:
|
||||||
find_unused_parameters = self.args.ddp_find_unused_parameters
|
find_unused_parameters = self.args.ddp_find_unused_parameters
|
||||||
@@ -773,6 +778,8 @@ class Trainer:
|
|||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
|
|
||||||
|
self.is_in_train = True
|
||||||
|
|
||||||
if "model_path" in kwargs:
|
if "model_path" in kwargs:
|
||||||
resume_from_checkpoint = kwargs.pop("model_path")
|
resume_from_checkpoint = kwargs.pop("model_path")
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@@ -1088,6 +1095,12 @@ class Trainer:
|
|||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
self.model_wrapped = self.model
|
self.model_wrapped = self.model
|
||||||
gc.collect() # force memory release
|
gc.collect() # force memory release
|
||||||
|
# to restore normal behavior outside of train, replay the place_model_on_device logic w/o deepspeed
|
||||||
|
self.place_model_on_device = self.args.place_model_on_device
|
||||||
|
if self.is_model_parallel:
|
||||||
|
self.place_model_on_device = False
|
||||||
|
|
||||||
|
self.is_in_train = False
|
||||||
|
|
||||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||||
|
|
||||||
@@ -1689,6 +1702,11 @@ class Trainer:
|
|||||||
|
|
||||||
model = self._wrap_model(self.model, training=False)
|
model = self._wrap_model(self.model, training=False)
|
||||||
|
|
||||||
|
# if full fp16 is wanted on eval and this ``evaluate`` or ``predict`` isn't called while
|
||||||
|
# ``train`` is running, halve it first and then put it on the device
|
||||||
|
if not self.is_in_train and self.args.fp16_full_eval:
|
||||||
|
model = model.half().to(self.args.device)
|
||||||
|
|
||||||
batch_size = dataloader.batch_size
|
batch_size = dataloader.batch_size
|
||||||
num_examples = self.num_examples(dataloader)
|
num_examples = self.num_examples(dataloader)
|
||||||
logger.info("***** Running %s *****", description)
|
logger.info("***** Running %s *****", description)
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class TrainingArguments:
|
|||||||
:func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
|
:func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
|
||||||
initialized parameters.
|
initialized parameters.
|
||||||
fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
|
Whether to use 16-bit (mixed) precision training instead of 32-bit training.
|
||||||
fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
|
fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
|
||||||
For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
|
For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
|
||||||
on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
|
on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
|
||||||
@@ -163,6 +163,9 @@ class TrainingArguments:
|
|||||||
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
|
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
|
||||||
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
|
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
|
||||||
other choices will force the requested backend.
|
other choices will force the requested backend.
|
||||||
|
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but
|
||||||
|
can harm metric values.
|
||||||
local_rank (:obj:`int`, `optional`, defaults to -1):
|
local_rank (:obj:`int`, `optional`, defaults to -1):
|
||||||
Rank of the process during distributed training.
|
Rank of the process during distributed training.
|
||||||
tpu_num_cores (:obj:`int`, `optional`):
|
tpu_num_cores (:obj:`int`, `optional`):
|
||||||
@@ -353,7 +356,7 @@ class TrainingArguments:
|
|||||||
|
|
||||||
fp16: bool = field(
|
fp16: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA Apex) instead of 32-bit"},
|
metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"},
|
||||||
)
|
)
|
||||||
fp16_opt_level: str = field(
|
fp16_opt_level: str = field(
|
||||||
default="O1",
|
default="O1",
|
||||||
@@ -368,6 +371,10 @@ class TrainingArguments:
|
|||||||
default="auto",
|
default="auto",
|
||||||
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
|
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
|
||||||
)
|
)
|
||||||
|
fp16_full_eval: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
|
||||||
|
)
|
||||||
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||||
|
|
||||||
tpu_num_cores: Optional[int] = field(
|
tpu_num_cores: Optional[int] = field(
|
||||||
@@ -488,8 +495,10 @@ class TrainingArguments:
|
|||||||
if self.run_name is None:
|
if self.run_name is None:
|
||||||
self.run_name = self.output_dir
|
self.run_name = self.output_dir
|
||||||
|
|
||||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval):
|
||||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
raise ValueError(
|
||||||
|
"Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices."
|
||||||
|
)
|
||||||
if self.report_to is None:
|
if self.report_to is None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -29,6 +30,7 @@ from transformers.testing_utils import (
|
|||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -912,6 +914,62 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
trainer = get_regression_trainer(skip_memory_metrics=True)
|
trainer = get_regression_trainer(skip_memory_metrics=True)
|
||||||
self.check_mem_metrics(trainer, self.assertNotIn)
|
self.check_mem_metrics(trainer, self.assertNotIn)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_fp16_full_eval(self):
|
||||||
|
|
||||||
|
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
|
||||||
|
# it's using pretty large safety margins, but small enough to detect broken functionality.
|
||||||
|
debug = 0
|
||||||
|
|
||||||
|
bs = 8
|
||||||
|
# make the params somewhat big so that there will be enough RAM consumed to be able to
|
||||||
|
# measure things. We should get about 64KB for a+b in fp32
|
||||||
|
a = torch.ones(1000, bs) + 0.001
|
||||||
|
b = torch.ones(1000, bs) - 0.001
|
||||||
|
|
||||||
|
# 1. baseline: default fp32 evaluation (memory metrics enabled)
|
||||||
|
trainer = get_regression_trainer(a=a, b=b, eval_len=16)
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
del trainer
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
fp32_init = metrics["init_mem_gpu_alloc_delta"]
|
||||||
|
fp32_eval = metrics["eval_mem_gpu_alloc_delta"]
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f"fp32_init {fp32_init}")
|
||||||
|
print(f"fp32_eval {fp32_eval}")
|
||||||
|
|
||||||
|
# here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
|
||||||
|
# perfect world: fp32_init == 64<<10
|
||||||
|
self.assertGreater(fp32_init, 59_000)
|
||||||
|
# after eval should be no extra memory allocated - with a small margin (other than the peak
|
||||||
|
# memory consumption for the forward calculation that gets recovered)
|
||||||
|
# perfect world: fp32_eval == close to zero
|
||||||
|
self.assertLess(fp32_eval, 5_000)
|
||||||
|
|
||||||
|
# 2. with fp16_full_eval enabled (memory metrics are still enabled and read below)
|
||||||
|
trainer = get_regression_trainer(a=a, b=b, eval_len=16, fp16_full_eval=True)
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
fp16_init = metrics["init_mem_gpu_alloc_delta"]
|
||||||
|
fp16_eval = metrics["eval_mem_gpu_alloc_delta"]
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f"fp16_init {fp16_init}")
|
||||||
|
print(f"fp16_eval {fp16_eval}")
|
||||||
|
|
||||||
|
# here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
|
||||||
|
# perfect world: fp16_init == close to zero
|
||||||
|
self.assertLess(fp16_init, 5_000)
|
||||||
|
# here we put the model on device in eval and only `half()` of it, i.e. about 32K (again we ignore the peak margin which gets returned back)
|
||||||
|
# perfect world: fp16_eval == 32<<10
|
||||||
|
self.assertGreater(fp16_eval, 27_000)
|
||||||
|
|
||||||
|
# 3. relative comparison fp32 vs full fp16
|
||||||
|
# fp16_eval should be about half of fp32_init
|
||||||
|
# perfect world: fp32_init/2 == fp16_eval
|
||||||
|
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
|
|||||||
Reference in New Issue
Block a user