Update label2id in the model config for run_glue (#13334)
This commit is contained in:
@@ -380,6 +380,9 @@ def main():
|
|||||||
if label_to_id is not None:
|
if label_to_id is not None:
|
||||||
model.config.label2id = label_to_id
|
model.config.label2id = label_to_id
|
||||||
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
elif data_args.task_name is not None and not is_regression:
|
||||||
|
model.config.label2id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -288,6 +288,9 @@ def main():
|
|||||||
if label_to_id is not None:
|
if label_to_id is not None:
|
||||||
model.config.label2id = label_to_id
|
model.config.label2id = label_to_id
|
||||||
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
elif args.task_name is not None and not is_regression:
|
||||||
|
model.config.label2id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
|
||||||
padding = "max_length" if args.pad_to_max_length else False
|
padding = "max_length" if args.pad_to_max_length else False
|
||||||
|
|
||||||
|
|||||||
@@ -355,6 +355,9 @@ def main():
|
|||||||
if label_to_id is not None:
|
if label_to_id is not None:
|
||||||
config.label2id = label_to_id
|
config.label2id = label_to_id
|
||||||
config.id2label = {id: label for label, id in config.label2id.items()}
|
config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
elif data_args.task_name is not None and not is_regression:
|
||||||
|
config.label2id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user