[s2s] run_eval.py parses generate_kwargs (#6948)

This commit is contained in:
Sam Shleifer
2020-09-04 14:19:31 -04:00
committed by GitHub
parent 6078b12098
commit a4fc0c80b1
3 changed files with 36 additions and 24 deletions

View File

@@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger = getLogger(__name__)
try:
from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, use_task_specific_params
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -36,7 +36,6 @@ def generate_summaries_or_translations(
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
decoder_start_token_id=None,
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
@@ -59,7 +58,6 @@ def generate_summaries_or_translations(
summaries = model.generate(
input_ids=batch.input_ids,
attention_mask=batch.attention_mask,
decoder_start_token_id=decoder_start_token_id,
**generate_kwargs,
)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
@@ -77,30 +75,20 @@ def run_generate():
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("save_path", type=str, help="where to save summaries")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument(
"--score_path",
type=str,
required=False,
default="metrics.json",
help="where to save the rouge score in json format",
)
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--decoder_start_token_id",
type=int,
default=None,
required=False,
help="Defaults to using config",
)
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest)
if parsed:
print(f"parsed the following generate kwargs: {parsed}")
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
@@ -115,7 +103,7 @@ def run_generate():
device=args.device,
fp16=args.fp16,
task=args.task,
decoder_start_token_id=args.decoder_start_token_id,
**parsed,
)
if args.reference_path is None:
return