[s2s] two stage run_distributed_eval.py (#7105)
This commit is contained in:
139
examples/seq2seq/run_distributed_eval.py
Normal file
139
examples/seq2seq/run_distributed_eval.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import argparse
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||
except ImportError:
|
||||
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def eval_data_dir(
|
||||
data_dir,
|
||||
save_dir: str,
|
||||
model_name: str,
|
||||
bs: int = 8,
|
||||
max_source_length: int = 1024,
|
||||
type_path="val",
|
||||
n_obs=None,
|
||||
fp16=False,
|
||||
save_source=False,
|
||||
num_beams: int = 4,
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
||||
model_name = str(model_name)
|
||||
assert local_rank is not None
|
||||
torch.distributed.init_process_group(backend="nccl", rank=local_rank)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
|
||||
torch.cuda.set_device(local_rank)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
|
||||
use_task_specific_params(model, task) # update config with task specific params
|
||||
ds = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length=1024,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
prefix=model.config.prefix,
|
||||
)
|
||||
sampler = ds.make_sortish_sampler(bs, distributed=True)
|
||||
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
|
||||
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
|
||||
results = []
|
||||
for batch in tqdm(data_loader):
|
||||
summaries = model.generate(
|
||||
input_ids=batch["input_ids"].to(model.device),
|
||||
attention_mask=batch["attention_mask"].to(model.device),
|
||||
num_beams=num_beams,
|
||||
**generate_kwargs,
|
||||
)
|
||||
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
|
||||
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
|
||||
if save_source:
|
||||
docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
|
||||
for i in range(len(labels)):
|
||||
label, pred = labels[i], preds[i]
|
||||
if save_source:
|
||||
results.append(dict(pred=pred, label=label, source=docs[i]))
|
||||
else:
|
||||
results.append(dict(pred=pred, label=label))
|
||||
save_json(results, save_path)
|
||||
return results
|
||||
|
||||
|
||||
def run_generate():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
|
||||
)
|
||||
parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
help="like facebook/bart-large-cnn,t5-base, etc.",
|
||||
default="sshleifer/distilbart-xsum-12-3",
|
||||
)
|
||||
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
|
||||
parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
||||
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
|
||||
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||
parser.add_argument(
|
||||
"--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--save_source", action="store_true")
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
parsed = parse_numeric_cl_kwargs(rest)
|
||||
if parsed:
|
||||
print(f"parsed the following generate kwargs: {parsed}")
|
||||
Path(args.save_dir).mkdir(exist_ok=True)
|
||||
if args.reference_path is None and Path(args.score_path).exists():
|
||||
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
|
||||
eval_data_dir(
|
||||
args.input_path,
|
||||
args.save_dir,
|
||||
args.model_name,
|
||||
prefix=args.prefix,
|
||||
batch_size=args.bs,
|
||||
fp16=args.fp16,
|
||||
task=args.task,
|
||||
local_rank=args.local_rank,
|
||||
n_obs=args.n_obs,
|
||||
save_source=args.save_source,
|
||||
**parsed,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Usage for MT:
|
||||
run_generate()
|
||||
Reference in New Issue
Block a user