WIP: Support for Training with BF16 (#13207)

* started bf16 integration

* minor changes

* code now runs

* style

* lay foundation for bf16 testing

* lay foundation for bf16 testing

* start the tests

* better bf16 check

* style

* 2 separate checkers - one for bf16 support, another for bf16+autocast

* Update src/transformers/training_args.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* a couple of comment resolutions

* more comment resolutions

* resolved a small bug

* just some print statemtns

* added todo marking

* added a todo

* adjust for API change s/fast_dtype/dtype/

* fix style

* merge 2 bf16 util functions

* bf16 now does scaling too

* Add support for bfloat16

* Revert T5 layernorm to float32

This is based on the comment at https://github.com/huggingface/transformers/pull/14448/files#r752660929 and the PyTorch PR https://github.com/pytorch/pytorch/pull/66920 .

* Add comment about conversion to float32 before returning the numpy data

* Add comment about AMP-bfloat16 incompatibility

* Fix formatting

* typo

* reformer / bf16

* cleanup

* require at least pt-1.10

* fix

* will deal with deepspeed separately

* cleanup

* revert

* cleanup

* fp16_full_eval and bf16_full_eval are separate modes

* proper deprecation

* cleanup

* test and fixes

* spelling

* cleanup

* add a note that this API is experimental

Co-authored-by: jamie <jamie@cortx.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: suriya <suriya@cortx.com>
Co-authored-by: Manuel R. Ciosici <manuelrciosici@gmail.com>
This commit is contained in:
Jamie DeAntonis
2021-11-30 21:00:47 -05:00
committed by GitHub
parent fc1d97f29d
commit 70996a5420
8 changed files with 246 additions and 66 deletions

View File

@@ -49,6 +49,7 @@ from .file_utils import (
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
is_torch_tpu_available,
is_torchaudio_available,
is_vision_available,
@@ -493,6 +494,14 @@ def require_torch_gpu(test_case):
return test_case
def require_torch_bf16(test_case):
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
if not is_torch_bf16_available():
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
else:
return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""