From 9c9fe89f84f7aa8ec29f19c39a1bf7f1bca82fc3 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 3 Jan 2023 22:33:11 +0800 Subject: [PATCH] [run_clm example] add torch_dtype option for model load. (#20971) * [run_clm example] add torch_dtype option for model load. For the BLOOM 175B model, peak memory during inference is reduced by about 350G. The BLOOM weights on the model hub are bfloat16. Signed-off-by: Wang, Yi A * add other types to the option * fix style Signed-off-by: Wang, Yi A --- examples/pytorch/language-modeling/run_clm.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index a7b527ab34..fc62c614bd 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -30,6 +30,7 @@ from itertools import chain from typing import Optional import datasets +import torch from datasets import load_dataset import evaluate @@ -119,6 +120,16 @@ class ModelArguments: ) }, ) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights."
+ ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) def __post_init__(self): if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): @@ -374,6 +385,11 @@ def main(): ) if model_args.model_name_or_path: + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -381,6 +397,7 @@ def main(): cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + torch_dtype=torch_dtype, ) else: model = AutoModelForCausalLM.from_config(config)