[s2s] distributed eval cleanup (#7110)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
@@ -18,6 +17,7 @@ try:
|
||||
except ImportError:
|
||||
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@@ -51,6 +51,8 @@ def eval_data_dir(
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
|
||||
use_task_specific_params(model, task) # update config with task specific params
|
||||
if max_source_length is None:
|
||||
max_source_length = tokenizer.model_max_length
|
||||
ds = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir,
|
||||
@@ -97,9 +99,11 @@ def run_generate():
|
||||
default="sshleifer/distilbart-xsum-12-3",
|
||||
)
|
||||
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
|
||||
parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
|
||||
parser.add_argument("--max_source_length", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
|
||||
)
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
||||
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
|
||||
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||
parser.add_argument(
|
||||
@@ -113,24 +117,23 @@ def run_generate():
|
||||
parser.add_argument("--save_source", action="store_true")
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
parsed = parse_numeric_cl_kwargs(rest)
|
||||
if parsed:
|
||||
print(f"parsed the following generate kwargs: {parsed}")
|
||||
generate_kwargs = parse_numeric_cl_kwargs(rest)
|
||||
if generate_kwargs:
|
||||
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
||||
Path(args.save_dir).mkdir(exist_ok=True)
|
||||
if args.reference_path is None and Path(args.score_path).exists():
|
||||
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
|
||||
eval_data_dir(
|
||||
args.input_path,
|
||||
args.save_dir,
|
||||
args.model_name,
|
||||
prefix=args.prefix,
|
||||
type_path=args.type_path,
|
||||
batch_size=args.bs,
|
||||
fp16=args.fp16,
|
||||
task=args.task,
|
||||
local_rank=args.local_rank,
|
||||
n_obs=args.n_obs,
|
||||
save_source=args.save_source,
|
||||
**parsed,
|
||||
max_source_length=args.max_source_length,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix
|
||||
self.prefix = prefix if prefix is not None else ""
|
||||
|
||||
if n_obs is not None:
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
|
||||
Reference in New Issue
Block a user