[run_clm example] add torch_dtype option for model load. (#20971)
* [run_clm example] add torch_dtype option for model load. for BLOOM 175B model. peak memory will reduce about 350G for inference. the weight of BLOOM in model hub is bfloat16 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add other type in option * fix style Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from itertools import chain
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
@@ -119,6 +120,16 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
torch_dtype: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||||
|
"dtype will be automatically derived from the model's weights."
|
||||||
|
),
|
||||||
|
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
||||||
@@ -374,6 +385,11 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_args.model_name_or_path:
|
if model_args.model_name_or_path:
|
||||||
|
torch_dtype = (
|
||||||
|
model_args.torch_dtype
|
||||||
|
if model_args.torch_dtype in ["auto", None]
|
||||||
|
else getattr(torch, model_args.torch_dtype)
|
||||||
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||||
@@ -381,6 +397,7 @@ def main():
|
|||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_config(config)
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user