[s2s] distributed eval cleanup (#7110)

This commit is contained in:
Sam Shleifer
2020-09-13 23:40:38 -04:00
committed by GitHub
parent 54395d87a6
commit de9e297964
2 changed files with 15 additions and 11 deletions

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import warnings
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
@@ -18,6 +17,7 @@ try:
except ImportError: except ImportError:
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -51,6 +51,8 @@ def eval_data_dir(
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type. logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
use_task_specific_params(model, task) # update config with task specific params use_task_specific_params(model, task) # update config with task specific params
if max_source_length is None:
max_source_length = tokenizer.model_max_length
ds = Seq2SeqDataset( ds = Seq2SeqDataset(
tokenizer, tokenizer,
data_dir, data_dir,
@@ -97,9 +99,11 @@ def run_generate():
default="sshleifer/distilbart-xsum-12-3", default="sshleifer/distilbart-xsum-12-3",
) )
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen") parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test") parser.add_argument("--max_source_length", type=int, default=None)
parser.add_argument(
"--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
)
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target") parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument( parser.add_argument(
@@ -113,24 +117,23 @@ def run_generate():
parser.add_argument("--save_source", action="store_true") parser.add_argument("--save_source", action="store_true")
args, rest = parser.parse_known_args() args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest) generate_kwargs = parse_numeric_cl_kwargs(rest)
if parsed: if generate_kwargs:
print(f"parsed the following generate kwargs: {parsed}") print(f"parsed the following generate kwargs: {generate_kwargs}")
Path(args.save_dir).mkdir(exist_ok=True) Path(args.save_dir).mkdir(exist_ok=True)
if args.reference_path is None and Path(args.score_path).exists():
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
eval_data_dir( eval_data_dir(
args.input_path, args.input_path,
args.save_dir, args.save_dir,
args.model_name, args.model_name,
prefix=args.prefix, type_path=args.type_path,
batch_size=args.bs, batch_size=args.bs,
fp16=args.fp16, fp16=args.fp16,
task=args.task, task=args.task,
local_rank=args.local_rank, local_rank=args.local_rank,
n_obs=args.n_obs, n_obs=args.n_obs,
save_source=args.save_source, save_source=args.save_source,
**parsed, max_source_length=args.max_source_length,
**generate_kwargs,
) )

View File

@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
self.max_target_length = max_target_length self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.prefix = prefix self.prefix = prefix if prefix is not None else ""
if n_obs is not None: if n_obs is not None:
self.src_lens = self.src_lens[:n_obs] self.src_lens = self.src_lens[:n_obs]
self.pad_token_id = self.tokenizer.pad_token_id self.pad_token_id = self.tokenizer.pad_token_id