[s2s] run_eval.py parses generate_kwargs (#6948)
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
||||
import pickle
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
from typing import Callable, Dict, Iterable, List, Union
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
@@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
npars = len(model_grads)
|
||||
assert any(model_grads), f"none of {npars} weights require grad"
|
||||
|
||||
|
||||
# CLI Parsing utils
|
||||
|
||||
|
||||
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
|
||||
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
|
||||
result = {}
|
||||
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
||||
num_pairs = len(unparsed_args) // 2
|
||||
for pair_num in range(num_pairs):
|
||||
i = 2 * pair_num
|
||||
assert unparsed_args[i].startswith("--")
|
||||
try:
|
||||
value = int(unparsed_args[i + 1])
|
||||
except ValueError:
|
||||
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
|
||||
|
||||
result[unparsed_args[i][2:]] = value
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user