updating classification example
This commit is contained in:
@@ -228,10 +228,10 @@ def main():
|
|||||||
|
|
||||||
# Prepare data loader
|
# Prepare data loader
|
||||||
train_examples = processor.get_train_examples(args.data_dir)
|
train_examples = processor.get_train_examples(args.data_dir)
|
||||||
cached_train_features_file = args.data_dir + '_{0}_{1}_{2}'.format(
|
cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(
|
||||||
list(filter(None, args.bert_model.split('/'))).pop(),
|
list(filter(None, args.bert_model.split('/'))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task_name))
|
str(task_name)))
|
||||||
try:
|
try:
|
||||||
with open(cached_train_features_file, "rb") as reader:
|
with open(cached_train_features_file, "rb") as reader:
|
||||||
train_features = pickle.load(reader)
|
train_features = pickle.load(reader)
|
||||||
@@ -311,7 +311,7 @@ def main():
|
|||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# define a new function to compute loss values for both output_modes
|
# define a new function to compute loss values for both output_modes
|
||||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
logits = model(input_ids, segment_ids, input_mask)
|
||||||
|
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
@@ -380,6 +380,22 @@ def main():
|
|||||||
### Evaluation
|
### Evaluation
|
||||||
if args.do_eval:
|
if args.do_eval:
|
||||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||||
|
cached_train_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
|
||||||
|
list(filter(None, args.bert_model.split('/'))).pop(),
|
||||||
|
str(args.max_seq_length),
|
||||||
|
str(task_name)))
|
||||||
|
try:
|
||||||
|
with open(cached_train_features_file, "rb") as reader:
|
||||||
|
train_features = pickle.load(reader)
|
||||||
|
except:
|
||||||
|
train_features = convert_examples_to_features(
|
||||||
|
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||||
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
|
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||||
|
with open(cached_train_features_file, "wb") as writer:
|
||||||
|
pickle.dump(train_features, writer)
|
||||||
|
|
||||||
|
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||||
logger.info("***** Running evaluation *****")
|
logger.info("***** Running evaluation *****")
|
||||||
@@ -414,7 +430,7 @@ def main():
|
|||||||
label_ids = label_ids.to(device)
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
logits = model(input_ids, segment_ids, input_mask)
|
||||||
|
|
||||||
# create eval loss and other metric required by the task
|
# create eval loss and other metric required by the task
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
|
|||||||
Reference in New Issue
Block a user