[s2s run_eval] new features (#7109)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -228,6 +228,67 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
|
|||||||
--bs 32
|
--bs 32
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### run_eval tips and tricks
|
||||||
|
|
||||||
|
When using `run_eval.py`, the following features can be useful:
|
||||||
|
|
||||||
|
* if you running the script multiple times and want to make it easier to track what arguments produced that output, use `--dump-args`. Along with the results it will also dump any custom params that were passed to the script. For example if you used: `--num_beams 8 --early_stopping true`, the output will be:
|
||||||
|
```
|
||||||
|
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True}
|
||||||
|
```
|
||||||
|
|
||||||
|
`--info` is an additional argument available for the same purpose of tracking the conditions of the experiment. It's useful to pass things that weren't in the argument list, e.g. a language pair `--info "lang:en-ru"`. But also if you pass `--info` without a value it will fallback to the current date/time string, e.g. `2020-09-13 18:44:43`.
|
||||||
|
|
||||||
|
If using `--dump-args --info`, the output will be:
|
||||||
|
|
||||||
|
```
|
||||||
|
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': '2020-09-13 18:44:43'}
|
||||||
|
```
|
||||||
|
|
||||||
|
If using `--dump-args --info "pair:en-ru chkpt=best`, the output will be:
|
||||||
|
|
||||||
|
```
|
||||||
|
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': 'pair=en-ru chkpt=best'}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
* if you need to perform a parametric search in order to find the best ones that lead to the highest BLEU score, let `run_eval_search.py` to do the searching for you.
|
||||||
|
|
||||||
|
The script accepts the exact same arguments as `run_eval.py`, plus an additional argument `--search`. The value of `--search` is parsed, reformatted and fed to ``run_eval.py`` as additional args.
|
||||||
|
|
||||||
|
The format for the `--search` value is a simple string with hparams and colon separated values to try, e.g.:
|
||||||
|
```
|
||||||
|
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
|
||||||
|
```
|
||||||
|
which will generate `12` `(2*3*2)` searches for a product of each hparam. For example the example that was just used will invoke `run_eval.py` repeatedly with:
|
||||||
|
|
||||||
|
```
|
||||||
|
--num_beams 5 --length_penalty 0.8 --early_stopping true
|
||||||
|
--num_beams 5 --length_penalty 0.8 --early_stopping false
|
||||||
|
[...]
|
||||||
|
--num_beams 10 --length_penalty 1.2 --early_stopping false
|
||||||
|
```
|
||||||
|
|
||||||
|
On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.
|
||||||
|
|
||||||
|
```
|
||||||
|
bleu | num_beams | length_penalty | early_stopping
|
||||||
|
----- | --------- | -------------- | --------------
|
||||||
|
26.71 | 5 | 1.1 | 1
|
||||||
|
26.66 | 5 | 0.9 | 1
|
||||||
|
26.66 | 5 | 0.9 | 0
|
||||||
|
26.41 | 5 | 1.1 | 0
|
||||||
|
21.94 | 1 | 0.9 | 1
|
||||||
|
21.94 | 1 | 0.9 | 0
|
||||||
|
21.94 | 1 | 1.1 | 1
|
||||||
|
21.94 | 1 | 1.1 | 0
|
||||||
|
|
||||||
|
Best score args:
|
||||||
|
stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --reference_path data/en-ru/val.target --score_path data/en-ru/test_bleu.json --bs 8 --task translation --num_beams 5 --length_penalty 1.1 --early_stopping True
|
||||||
|
```
|
||||||
|
|
||||||
|
If you pass `--info "some experiment-specific info"` it will get printed before the results table - this is useful for scripting and multiple runs, so one can tell the different sets of results from each other.
|
||||||
|
|
||||||
|
|
||||||
### DistilBART
|
### DistilBART
|
||||||

|

|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
@@ -15,9 +16,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
|
from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
|
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
@@ -72,7 +73,26 @@ def generate_summaries_or_translations(
|
|||||||
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
|
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
|
||||||
|
|
||||||
|
|
||||||
def run_generate():
|
def datetime_now():
|
||||||
|
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
|
||||||
|
def run_generate(verbose=True):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Takes input text, generates output, and then using reference calculates the BLEU scores.
|
||||||
|
|
||||||
|
The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a tuple: ``(scores, params}``
|
||||||
|
- ``scores``: a dict of scores data ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}``
|
||||||
|
- ``params``: a dict of custom params, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}``
|
||||||
|
"""
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
||||||
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
||||||
@@ -89,11 +109,19 @@ def run_generate():
|
|||||||
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
|
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
|
||||||
)
|
)
|
||||||
parser.add_argument("--fp16", action="store_true")
|
parser.add_argument("--fp16", action="store_true")
|
||||||
|
parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results")
|
||||||
|
parser.add_argument(
|
||||||
|
"--info",
|
||||||
|
nargs="?",
|
||||||
|
type=str,
|
||||||
|
const=datetime_now(),
|
||||||
|
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
|
||||||
|
)
|
||||||
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
parsed = parse_numeric_cl_kwargs(rest)
|
parsed_args = parse_numeric_n_bool_cl_kwargs(rest)
|
||||||
if parsed:
|
if parsed_args and verbose:
|
||||||
print(f"parsed the following generate kwargs: {parsed}")
|
print(f"parsed the following generate kwargs: {parsed_args}")
|
||||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||||
if args.n_obs > 0:
|
if args.n_obs > 0:
|
||||||
examples = examples[: args.n_obs]
|
examples = examples[: args.n_obs]
|
||||||
@@ -109,23 +137,35 @@ def run_generate():
|
|||||||
fp16=args.fp16,
|
fp16=args.fp16,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
prefix=args.prefix,
|
prefix=args.prefix,
|
||||||
**parsed,
|
**parsed_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.reference_path is None:
|
if args.reference_path is None:
|
||||||
return
|
return {}
|
||||||
|
|
||||||
# Compute scores
|
# Compute scores
|
||||||
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
|
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
|
||||||
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
||||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
||||||
scores: dict = score_fn(output_lns, reference_lns)
|
scores: dict = score_fn(output_lns, reference_lns)
|
||||||
scores.update(runtime_metrics)
|
scores.update(runtime_metrics)
|
||||||
print(scores)
|
|
||||||
|
if args.dump_args:
|
||||||
|
scores.update(parsed_args)
|
||||||
|
if args.info:
|
||||||
|
scores["info"] = args.info
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(*scores)
|
||||||
|
|
||||||
if args.score_path is not None:
|
if args.score_path is not None:
|
||||||
json.dump(scores, open(args.score_path, "w"))
|
path = args.score_path
|
||||||
|
json.dump(scores, open(path, "w"))
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Usage for MT:
|
# Usage for MT:
|
||||||
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
|
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
|
||||||
run_generate()
|
run_generate(verbose=True)
|
||||||
|
|||||||
139
examples/seq2seq/run_eval_search.py
Normal file
139
examples/seq2seq/run_eval_search.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import argparse
|
||||||
|
import itertools
|
||||||
|
import operator
|
||||||
|
import sys
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .run_eval import datetime_now, run_generate
|
||||||
|
except ImportError:
|
||||||
|
from run_eval import datetime_now, run_generate
|
||||||
|
|
||||||
|
|
||||||
|
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
|
||||||
|
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
|
||||||
|
task_score_names = {
|
||||||
|
"translation": ["bleu"],
|
||||||
|
"translation_en_to_de": ["bleu"],
|
||||||
|
"summarization": ["rouge1", "rouge2", "rougeL"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_search_arg(search):
|
||||||
|
groups = search.split()
|
||||||
|
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
|
||||||
|
entry_names = list(entries.keys())
|
||||||
|
sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
|
||||||
|
matrix = [list(x) for x in itertools.product(*sets)]
|
||||||
|
return matrix, entry_names
|
||||||
|
|
||||||
|
|
||||||
|
def run_search():
|
||||||
|
"""
|
||||||
|
Run parametric search over the desired hparam space with help of ``run_eval.py``.
|
||||||
|
|
||||||
|
All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.
|
||||||
|
|
||||||
|
The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
|
||||||
|
```
|
||||||
|
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
|
||||||
|
```
|
||||||
|
which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with:
|
||||||
|
|
||||||
|
```
|
||||||
|
--num_beams 5 --length_penalty 0.8 --early_stopping true
|
||||||
|
--num_beams 5 --length_penalty 0.8 --early_stopping false
|
||||||
|
[...]
|
||||||
|
--num_beams 10 --length_penalty 1.2 --early_stopping false
|
||||||
|
```
|
||||||
|
|
||||||
|
On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
prog = sys.argv[0]
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--search",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--info",
|
||||||
|
nargs="?",
|
||||||
|
type=str,
|
||||||
|
const=datetime_now(),
|
||||||
|
help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
|
||||||
|
)
|
||||||
|
args, args_main = parser.parse_known_args()
|
||||||
|
# we share some of the args
|
||||||
|
args_main.extend(["--task", args.task])
|
||||||
|
args_normal = [prog] + args_main
|
||||||
|
|
||||||
|
matrix, col_names = parse_search_arg(args.search)
|
||||||
|
col_names[0:0] = task_score_names[args.task] # score cols first
|
||||||
|
col_widths = {col: len(str(col)) for col in col_names}
|
||||||
|
results = []
|
||||||
|
for r in matrix:
|
||||||
|
hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
|
||||||
|
args_exp = " ".join(r).split()
|
||||||
|
args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM
|
||||||
|
sys.argv = args_normal + args_exp
|
||||||
|
|
||||||
|
# XXX: need to trap CUDA OOM and lower args.bs if that happens and retry
|
||||||
|
|
||||||
|
scores = run_generate(verbose=False)
|
||||||
|
# make sure scores are first in the table
|
||||||
|
result = OrderedDict()
|
||||||
|
for score in task_score_names[args.task]:
|
||||||
|
result[score] = scores[score]
|
||||||
|
result.update(hparams)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# find widest entries
|
||||||
|
for k, v in result.items():
|
||||||
|
l = len(str(v))
|
||||||
|
if l > col_widths[k]:
|
||||||
|
col_widths[k] = l
|
||||||
|
|
||||||
|
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
|
||||||
|
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
|
||||||
|
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
|
||||||
|
for row in results_sorted:
|
||||||
|
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
|
||||||
|
|
||||||
|
best = results_sorted[0]
|
||||||
|
for score in task_score_names[args.task]:
|
||||||
|
del best[score]
|
||||||
|
best_args = [f"--{k} {v}" for k, v in best.items()]
|
||||||
|
dyn_args = ["--bs", str(args.bs)]
|
||||||
|
if args.info:
|
||||||
|
print(f"\nInfo: {args.info}")
|
||||||
|
print("\nBest score args:")
|
||||||
|
print(" ".join(args_main + best_args + dyn_args))
|
||||||
|
|
||||||
|
return results_sorted
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Usage:
|
||||||
|
# [normal-run_eval_search.py cmd plus] \
|
||||||
|
# --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
|
||||||
|
# $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
|
||||||
|
# --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
|
||||||
|
# --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
|
||||||
|
run_search()
|
||||||
@@ -23,6 +23,7 @@ from .distillation import distill_main, evaluate_checkpoint
|
|||||||
from .finetune import SummarizationModule, main
|
from .finetune import SummarizationModule, main
|
||||||
from .pack_dataset import pack_data_dir
|
from .pack_dataset import pack_data_dir
|
||||||
from .run_eval import generate_summaries_or_translations, run_generate
|
from .run_eval import generate_summaries_or_translations, run_generate
|
||||||
|
from .run_eval_search import run_search
|
||||||
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||||
|
|
||||||
|
|
||||||
@@ -283,7 +284,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||||
def test_run_eval(model):
|
def test_run_eval(model):
|
||||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||||
@@ -311,6 +312,58 @@ def test_run_eval(model):
|
|||||||
assert Path(output_file_name).exists()
|
assert Path(output_file_name).exists()
|
||||||
os.remove(Path(output_file_name))
|
os.remove(Path(output_file_name))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
|
||||||
|
def test_run_eval_search(model):
|
||||||
|
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||||
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||||
|
assert not output_file_name.exists()
|
||||||
|
|
||||||
|
text = {
|
||||||
|
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
|
||||||
|
"de": [
|
||||||
|
"Maschinelles Lernen ist großartig, oder?",
|
||||||
|
"Ich esse gerne Bananen",
|
||||||
|
"Morgen ist wieder ein toller Tag!",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp_dir = Path(tempfile.mkdtemp())
|
||||||
|
score_path = str(tmp_dir / "scores.json")
|
||||||
|
reference_path = str(tmp_dir / "val.target")
|
||||||
|
_dump_articles(input_file_name, text["en"])
|
||||||
|
_dump_articles(reference_path, text["de"])
|
||||||
|
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||||
|
testargs = [
|
||||||
|
"run_eval_search.py",
|
||||||
|
model,
|
||||||
|
str(input_file_name),
|
||||||
|
str(output_file_name),
|
||||||
|
"--score_path",
|
||||||
|
score_path,
|
||||||
|
"--reference_path",
|
||||||
|
reference_path,
|
||||||
|
"--task",
|
||||||
|
task,
|
||||||
|
"--search",
|
||||||
|
"num_beams=1:2 length_penalty=0.9:1.0",
|
||||||
|
]
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
with CaptureStdout() as cs:
|
||||||
|
run_search()
|
||||||
|
expected_strings = [" num_beams | length_penalty", model, "Best score args"]
|
||||||
|
un_expected_strings = ["Info"]
|
||||||
|
if "translation" in task:
|
||||||
|
expected_strings.append("bleu")
|
||||||
|
else:
|
||||||
|
expected_strings.extend(["rouge1", "rouge2", "rougeL"])
|
||||||
|
for w in expected_strings:
|
||||||
|
assert w in cs.out
|
||||||
|
for w in un_expected_strings:
|
||||||
|
assert w not in cs.out
|
||||||
|
assert Path(output_file_name).exists()
|
||||||
|
os.remove(Path(output_file_name))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["model"],
|
["model"],
|
||||||
|
|||||||
@@ -377,18 +377,26 @@ def assert_not_all_frozen(model):
|
|||||||
# CLI Parsing utils
|
# CLI Parsing utils
|
||||||
|
|
||||||
|
|
||||||
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
|
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
||||||
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
|
"""
|
||||||
|
Parse an argv list of unspecified command line args to a dict.
|
||||||
|
Assumes all values are either numeric or boolean in the form of true/false.
|
||||||
|
"""
|
||||||
result = {}
|
result = {}
|
||||||
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
||||||
num_pairs = len(unparsed_args) // 2
|
num_pairs = len(unparsed_args) // 2
|
||||||
for pair_num in range(num_pairs):
|
for pair_num in range(num_pairs):
|
||||||
i = 2 * pair_num
|
i = 2 * pair_num
|
||||||
assert unparsed_args[i].startswith("--")
|
assert unparsed_args[i].startswith("--")
|
||||||
try:
|
if unparsed_args[i + 1].lower() == "true":
|
||||||
value = int(unparsed_args[i + 1])
|
value = True
|
||||||
except ValueError:
|
elif unparsed_args[i + 1].lower() == "false":
|
||||||
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
|
value = False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
value = int(unparsed_args[i + 1])
|
||||||
|
except ValueError:
|
||||||
|
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
|
||||||
|
|
||||||
result[unparsed_args[i][2:]] = value
|
result[unparsed_args[i][2:]] = value
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user