From de9e2979647338bc9617dae68c5e9dccc413fb9f Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 13 Sep 2020 23:40:38 -0400 Subject: [PATCH] [s2s] distributed eval cleanup (#7110) --- examples/seq2seq/run_distributed_eval.py | 23 +++++++++++++---------- examples/seq2seq/utils.py | 3 ++- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py index ea1b3ed524..2478edfdc6 100644 --- a/examples/seq2seq/run_distributed_eval.py +++ b/examples/seq2seq/run_distributed_eval.py @@ -1,5 +1,4 @@ import argparse -import warnings from logging import getLogger from pathlib import Path from typing import Dict @@ -18,6 +17,7 @@ try: except ImportError: from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params + DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -51,6 +51,8 @@ def eval_data_dir( tokenizer = AutoTokenizer.from_pretrained(model_name) logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type. use_task_specific_params(model, task) # update config with task specific params + if max_source_length is None: + max_source_length = tokenizer.model_max_length ds = Seq2SeqDataset( tokenizer, data_dir, @@ -97,9 +99,11 @@ def run_generate(): default="sshleifer/distilbart-xsum-12-3", ) parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen") - parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test") + parser.add_argument("--max_source_length", type=int, default=None) + parser.add_argument( + "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test" + ) parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target") - parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics") parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") parser.add_argument( @@ -113,24 +117,23 @@ def run_generate(): parser.add_argument("--save_source", action="store_true") args, rest = parser.parse_known_args() - parsed = parse_numeric_cl_kwargs(rest) - if parsed: - print(f"parsed the following generate kwargs: {parsed}") + generate_kwargs = parse_numeric_cl_kwargs(rest) + if generate_kwargs: + print(f"parsed the following generate kwargs: {generate_kwargs}") Path(args.save_dir).mkdir(exist_ok=True) - if args.reference_path is None and Path(args.score_path).exists(): - warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.") eval_data_dir( args.input_path, args.save_dir, args.model_name, - prefix=args.prefix, + type_path=args.type_path, batch_size=args.bs, fp16=args.fp16, task=args.task, local_rank=args.local_rank, n_obs=args.n_obs, save_source=args.save_source, - **parsed, + max_source_length=args.max_source_length, + **generate_kwargs, ) diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index b732cde0bf..37f49b8073 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset): self.max_target_length = max_target_length assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" self.tokenizer = tokenizer - self.prefix = prefix + self.prefix = prefix if prefix is not None else "" + if n_obs is not None: self.src_lens = self.src_lens[:n_obs] self.pad_token_id = self.tokenizer.pad_token_id