Raise exceptions instead of asserts (#13907)
This commit is contained in:
@@ -68,8 +68,10 @@ def format_mrpc(data_dir, path_to_data):
|
|||||||
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
|
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
|
||||||
urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
|
urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
|
||||||
urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
|
urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
|
||||||
assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
|
if not os.path.isfile(mrpc_train_file):
|
||||||
assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
|
raise ValueError(f"Train data not found at {mrpc_train_file}")
|
||||||
|
if not os.path.isfile(mrpc_test_file):
|
||||||
|
raise ValueError(f"Test data not found at {mrpc_test_file}")
|
||||||
urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
|
urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
|
||||||
|
|
||||||
dev_ids = []
|
dev_ids = []
|
||||||
@@ -118,7 +120,8 @@ def get_tasks(task_names):
|
|||||||
else:
|
else:
|
||||||
tasks = []
|
tasks = []
|
||||||
for task_name in task_names:
|
for task_name in task_names:
|
||||||
assert task_name in TASKS, "Task %s not found!" % task_name
|
if task_name not in TASKS:
|
||||||
|
raise ValueError(f"Task {task_name} not found!")
|
||||||
tasks.append(task_name)
|
tasks.append(task_name)
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user