[Deepspeed] add support for bf16 mode (#14569)
* [WIP] add support for bf16 mode * prep for bf16 * prep for bf16 * fix; zero2/bf16 is ok * check bf16 is available * test fixes * enable zero3_bf16 * config files * docs * split stage_dtype; merge back to non-dtype-specific config file * fix doc * cleanup * cleanup * bfloat16 => bf16 to match the PR changes * s/zero_gather_fp16_weights_on_model_save/zero_gather_16bit_weights_on_model_save/; s/save_fp16_model/save_16bit_model/ * test fixes/skipping * move * fix * Update docs/source/main_classes/deepspeed.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * backticks * cleanup * cleanup * cleanup * new version * add note about grad accum in bf16 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -205,8 +205,19 @@ task_cmds = make_task_cmds()
|
||||
|
||||
ZERO2 = "zero2"
|
||||
ZERO3 = "zero3"
|
||||
|
||||
stages = [ZERO2, ZERO3]
|
||||
|
||||
# future preparation:
|
||||
# for now test just fp16, as these tests are quite slow
|
||||
# FP16 = "fp16"
|
||||
# BF16 = "bf16"
|
||||
#
|
||||
# dtypes = [FP16]
|
||||
# so just hardcoding --fp16 for now
|
||||
# if is_torch_bf16_available():
|
||||
# dtypes += [BF16]
|
||||
|
||||
|
||||
def parameterized_custom_name_func(func, param_num, param):
|
||||
# customize the test name generator function as we want both params to appear in the sub-test
|
||||
|
||||
Reference in New Issue
Block a user