From a994bf4076667d6885ee0596c35c90af297ad7b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9gory=20Ch=C3=A2tel?= Date: Wed, 5 Dec 2018 18:16:30 +0100 Subject: [PATCH] Fixing related to issue #83. --- examples/run_classifier.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 7cfa39dabf..475ab54c96 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -423,6 +423,12 @@ def main(): "mrpc": MrpcProcessor, } + num_labels_task = { + "cola": 2, + "mnli": 3, + "mrpc": 2, + } + if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() @@ -461,6 +467,7 @@ def main(): raise ValueError("Task not found: %s" % (task_name)) processor = processors[task_name]() + num_labels = num_labels_task[task_name] label_list = processor.get_labels() tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) @@ -474,7 +481,8 @@ def main(): # Prepare model model = BertForSequenceClassification.from_pretrained(args.bert_model, - cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) + cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), + num_labels = num_labels) if args.fp16: model.half() model.to(device)