[s2s] two stage run_distributed_eval.py (#7105)
This commit is contained in:
46
examples/seq2seq/aggregate_distributed_results.py
Normal file
46
examples/seq2seq/aggregate_distributed_results.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
|
||||||
|
except ImportError:
|
||||||
|
from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
|
||||||
|
|
||||||
|
|
||||||
|
def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Merge the per-rank ``rank*.json`` files found in ``result_dir``, score them, and save the outputs.

    Every partial file is loaded with ``load_json`` and is expected to yield a list of records
    that have ``pred`` and ``label`` keys (and optionally ``source``).

    Args:
        result_dir: directory that holds the ``rank*.json`` partial results.
        save_dir: directory to write into; created (with parents) if missing.
            Defaults to ``result_dir`` when not given.
        save_prefix: stem of the text files that are written; defaults to "generated".
        calc_bleu: score with ``calculate_bleu`` instead of ``calculate_rouge``.
        just_metrics: when true, write only ``metrics.json`` and skip the text files.

    Raises:
        ValueError: if the partial result files contain no records at all.
    """
    src_dir = Path(result_dir)
    # save_dir is optional in the signature, so fall back to the results
    # directory instead of crashing on Path(None).
    save_dir = src_dir if save_dir is None else Path(save_dir)

    # Sort the paths so the combined order does not depend on directory listing order.
    records = []
    for partial_result in sorted(src_dir.glob("rank*.json")):
        records.extend(load_json(partial_result))
    if not records:
        # Fail early with a clear message rather than with an IndexError on records[0] below.
        raise ValueError(f"No records found in {src_dir} (looked for files matching rank*.json)")

    save_dir.mkdir(parents=True, exist_ok=True)
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return

    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")

    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    # Sources are only present when the partial results were produced with them included.
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Expose combine_partial_results as a command line interface via python-fire.
    fire.Fire(combine_partial_results)
|
||||||
139
examples/seq2seq/run_distributed_eval.py
Normal file
139
examples/seq2seq/run_distributed_eval.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||||
|
except ImportError:
|
||||||
|
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||||
|
|
||||||
|
# "cuda" when torch reports an available GPU, otherwise "cpu".
# NOTE(review): not referenced by the code visible in this file — confirm whether it is still needed.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    save_source=False,
    num_beams: int = 4,
    task="summarization",
    local_rank=None,
    **generate_kwargs,
) -> Dict:
    """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json

    Args:
        data_dir: passed to ``Seq2SeqDataset`` as the data location.
        save_dir: directory the per-rank json file is written to.
        model_name: name or path given to ``from_pretrained`` for both model and tokenizer.
        bs: batch size used by the sampler and the data loader.
        max_source_length: passed to ``Seq2SeqDataset``.
        type_path: which split ``Seq2SeqDataset`` should read.
        n_obs: passed to ``Seq2SeqDataset``; ``None`` is forwarded unchanged.
        fp16: convert the model to half precision before generating.
        save_source: also decode the inputs and store them under the ``source`` key.
        num_beams: forwarded to ``model.generate``.
        task: key given to ``use_task_specific_params``.
        local_rank: rank of this process; required.
        **generate_kwargs: forwarded unchanged to ``model.generate``.

    Returns:
        The list of result records (dicts with ``pred``, ``label`` and optionally ``source``)
        that was also written to disk. The ``Dict`` annotation is kept for interface compatibility.

    Raises:
        ValueError: if ``local_rank`` is None.
    """
    model_name = str(model_name)
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if local_rank is None:
        raise ValueError("local_rank is required to run distributed evaluation")
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
    use_task_specific_params(model, task)  # update config with task specific params
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
    sampler = ds.make_sortish_sampler(bs, distributed=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # tokenizer.decode
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            num_beams=num_beams,
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
        # One record per label; predictions (and sources) are paired with labels by position.
        if save_source:
            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
            results.extend(
                dict(pred=pred, label=label, source=doc) for label, pred, doc in zip(labels, preds, docs)
            )
        else:
            results.extend(dict(pred=pred, label=label) for label, pred in zip(labels, preds))
    save_json(results, save_path)
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def run_generate():
    """Parse the command line and run ``eval_data_dir`` for this process.

    Known options are declared below; any remaining arguments are parsed with
    ``parse_numeric_cl_kwargs`` and forwarded to ``eval_data_dir`` as extra keyword
    arguments, which passes them on to ``model.generate``.
    """
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    # NOTE(review): this value is handed to eval_data_dir as data_dir, while the help text
    # describes a single file — confirm which one is intended.
    parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )

    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save_source", action="store_true")

    args, rest = parser.parse_known_args()
    parsed = parse_numeric_cl_kwargs(rest)
    if parsed:
        print(f"parsed the following generate kwargs: {parsed}")
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    if args.reference_path is None and Path(args.score_path).exists():
        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
    eval_data_dir(
        args.input_path,
        args.save_dir,
        args.model_name,
        # eval_data_dir names these parameters `type_path` and `bs`; any other keyword
        # would fall into **generate_kwargs and be sent to model.generate instead.
        type_path=args.prefix,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        save_source=args.save_source,
        **parsed,
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point; --local_rank is expected on the command line
    # (per its help text, it should be passed by distributed.launch).
    run_generate()
|
||||||
@@ -387,3 +387,10 @@ def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, fl
|
|||||||
|
|
||||||
result[unparsed_args[i][2:]] = value
|
result[unparsed_args[i][2:]] = value
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def write_txt_file(ordered_tgt, path):
    """Write every string of ``ordered_tgt`` to ``path``, one per line, replacing an existing file.

    Args:
        ordered_tgt: iterable of strings, written in iteration order.
        path: destination file; anything accepted by ``pathlib.Path``.
    """
    # The context manager flushes and closes the handle even when a write fails,
    # so the descriptor no longer leaks.
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user