Labels are now added to model config under id2label and label2id (#2945)

This commit is contained in:
Martin Malmsten
2020-02-21 14:53:05 +01:00
committed by GitHub
parent 53ce3854a1
commit 4452b44b90

View File

@@ -586,6 +586,8 @@ def main():
config = config_class.from_pretrained( config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path, args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels, num_labels=num_labels,
id2label={str(i): label for i, label in enumerate(labels)},
label2id={label: i for i, label in enumerate(labels)},
cache_dir=args.cache_dir if args.cache_dir else None, cache_dir=args.cache_dir if args.cache_dir else None,
) )
tokenizer = tokenizer_class.from_pretrained( tokenizer = tokenizer_class.from_pretrained(