Fixing related to issue #83.
This commit is contained in:
@@ -423,6 +423,12 @@ def main():
|
|||||||
"mrpc": MrpcProcessor,
|
"mrpc": MrpcProcessor,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
num_labels_task = {
|
||||||
|
"cola": 2,
|
||||||
|
"mnli": 3,
|
||||||
|
"mrpc": 2,
|
||||||
|
}
|
||||||
|
|
||||||
if args.local_rank == -1 or args.no_cuda:
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
n_gpu = torch.cuda.device_count()
|
n_gpu = torch.cuda.device_count()
|
||||||
@@ -461,6 +467,7 @@ def main():
|
|||||||
raise ValueError("Task not found: %s" % (task_name))
|
raise ValueError("Task not found: %s" % (task_name))
|
||||||
|
|
||||||
processor = processors[task_name]()
|
processor = processors[task_name]()
|
||||||
|
num_labels = num_labels_task[task_name]
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
@@ -474,7 +481,8 @@ def main():
|
|||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||||
|
num_labels = num_labels)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user