WIP: Support for Training with BF16 (#13207)
* started bf16 integration * minor changes * code now runs * style * lay foundation for bf16 testing * lay foundation for bf16 testing * start the tests * better bf16 check * style * 2 separate checkers - one for bf16 support, another for bf16+autocast * Update src/transformers/training_args.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * a couple of comment resolutions * more comment resolutions * resolved a small bug * just some print statemtns * added todo marking * added a todo * adjust for API change s/fast_dtype/dtype/ * fix style * merge 2 bf16 util functions * bf16 now does scaling too * Add support for bfloat16 * Revert T5 layernorm to float32 This is based on the comment at https://github.com/huggingface/transformers/pull/14448/files#r752660929 and the PyTorch PR https://github.com/pytorch/pytorch/pull/66920 . * Add comment about conversion to float32 before returning the numpy data * Add comment about AMP-bfloat16 incompatibility * Fix formatting * typo * reformer / bf16 * cleanup * require at least pt-1.10 * fix * will deal with deepspeed separately * cleanup * revert * cleanup * fp16_full_eval and bf16_full_eval are separate modes * proper deprecation * cleanup * test and fixes * spelling * cleanup * add a note that this API is experimental Co-authored-by: jamie <jamie@cortx.com> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: suriya <suriya@cortx.com> Co-authored-by: Manuel R. Ciosici <manuelrciosici@gmail.com>
This commit is contained in:
@@ -49,6 +49,7 @@ from .file_utils import (
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_bf16_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_vision_available,
|
||||
@@ -493,6 +494,14 @@ def require_torch_gpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
|
||||
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
|
||||
if not is_torch_bf16_available():
|
||||
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_datasets(test_case):
|
||||
"""Decorator marking a test that requires datasets."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user