make run_generation more generic for other devices (#25133)
* make run_generation more generic for other devices
* use Accelerate to support any device type it supports
* make style
* fix incorrect usage of accelerator.prepare_model
* use `PartialState` to make sure everything is running on the right device

---------

Co-authored-by: statelesshz <jihuazhong1@huawei.com>
This commit is contained in:
@@ -23,8 +23,9 @@ import inspect
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -88,13 +89,6 @@ the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famo
|
||||
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
#
|
||||
# Functions to prepare models' input
|
||||
#
|
||||
@@ -327,7 +321,11 @@ def main():
|
||||
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--use_cpu",
|
||||
action="store_true",
|
||||
help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available",
|
||||
)
|
||||
parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
@@ -337,12 +335,13 @@ def main():
|
||||
parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
# Initialize the distributed state.
|
||||
distributed_state = PartialState(cpu=args.use_cpu)
|
||||
|
||||
logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")
|
||||
logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}")
|
||||
|
||||
set_seed(args)
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Initialize the model and tokenizer
|
||||
try:
|
||||
@@ -355,7 +354,9 @@ def main():
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
model.to(args.device)
|
||||
|
||||
# Set the model to the right device
|
||||
model.to(distributed_state.device)
|
||||
|
||||
if args.fp16:
|
||||
model.half()
|
||||
@@ -382,7 +383,7 @@ def main():
|
||||
else:
|
||||
prefix = args.prefix if args.prefix else args.padding_text
|
||||
encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
encoded_prompt = encoded_prompt.to(args.device)
|
||||
encoded_prompt = encoded_prompt.to(distributed_state.device)
|
||||
|
||||
if encoded_prompt.size()[-1] == 0:
|
||||
input_ids = None
|
||||
|
||||
Reference in New Issue
Block a user