[s2s] distributed eval cleanup (#7110)
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -18,6 +17,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +51,8 @@ def eval_data_dir(
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
|
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
|
||||||
use_task_specific_params(model, task) # update config with task specific params
|
use_task_specific_params(model, task) # update config with task specific params
|
||||||
|
if max_source_length is None:
|
||||||
|
max_source_length = tokenizer.model_max_length
|
||||||
ds = Seq2SeqDataset(
|
ds = Seq2SeqDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir,
|
data_dir,
|
||||||
@@ -97,9 +99,11 @@ def run_generate():
|
|||||||
default="sshleifer/distilbart-xsum-12-3",
|
default="sshleifer/distilbart-xsum-12-3",
|
||||||
)
|
)
|
||||||
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
|
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
|
||||||
parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
|
parser.add_argument("--max_source_length", type=int, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
|
||||||
|
)
|
||||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
||||||
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
|
|
||||||
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
||||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -113,24 +117,23 @@ def run_generate():
|
|||||||
parser.add_argument("--save_source", action="store_true")
|
parser.add_argument("--save_source", action="store_true")
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
parsed = parse_numeric_cl_kwargs(rest)
|
generate_kwargs = parse_numeric_cl_kwargs(rest)
|
||||||
if parsed:
|
if generate_kwargs:
|
||||||
print(f"parsed the following generate kwargs: {parsed}")
|
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
||||||
Path(args.save_dir).mkdir(exist_ok=True)
|
Path(args.save_dir).mkdir(exist_ok=True)
|
||||||
if args.reference_path is None and Path(args.score_path).exists():
|
|
||||||
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
|
|
||||||
eval_data_dir(
|
eval_data_dir(
|
||||||
args.input_path,
|
args.input_path,
|
||||||
args.save_dir,
|
args.save_dir,
|
||||||
args.model_name,
|
args.model_name,
|
||||||
prefix=args.prefix,
|
type_path=args.type_path,
|
||||||
batch_size=args.bs,
|
batch_size=args.bs,
|
||||||
fp16=args.fp16,
|
fp16=args.fp16,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
local_rank=args.local_rank,
|
local_rank=args.local_rank,
|
||||||
n_obs=args.n_obs,
|
n_obs=args.n_obs,
|
||||||
save_source=args.save_source,
|
save_source=args.save_source,
|
||||||
**parsed,
|
max_source_length=args.max_source_length,
|
||||||
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
|
|||||||
self.max_target_length = max_target_length
|
self.max_target_length = max_target_length
|
||||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.prefix = prefix
|
self.prefix = prefix if prefix is not None else ""
|
||||||
|
|
||||||
if n_obs is not None:
|
if n_obs is not None:
|
||||||
self.src_lens = self.src_lens[:n_obs]
|
self.src_lens = self.src_lens[:n_obs]
|
||||||
self.pad_token_id = self.tokenizer.pad_token_id
|
self.pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
|||||||
Reference in New Issue
Block a user