add test related code
This commit is contained in:
@@ -126,6 +126,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
|
best_dev_acc, best_dev_loss = 0.0, 99999999999.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||||
@@ -167,6 +168,13 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||||
|
if results["eval_loss"] < best_dev_loss:
|
||||||
|
best_dev_acc = results["eval_acc"]
|
||||||
|
best_dev_loss = results["eval_loss"]
|
||||||
|
results_test = evaluate(args, model, tokenizer, test=True)
|
||||||
|
for key, value in results_test.items():
|
||||||
|
tb_writer.add_scalar('test_{}'.format(key), value, global_step)
|
||||||
|
logger.info("test acc: %s, loss: %s, global steps: %s", str(results_test['eval_acc']), str(results_test['eval_loss']), str(global_step))
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||||
logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step))
|
logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step))
|
||||||
@@ -196,14 +204,14 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix=""):
|
def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||||
eval_task_names = (args.task_name,)
|
eval_task_names = (args.task_name,)
|
||||||
eval_outputs_dirs = (args.output_dir,)
|
eval_outputs_dirs = (args.output_dir,)
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=not test, test=test)
|
||||||
|
|
||||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(eval_output_dir)
|
os.makedirs(eval_output_dir)
|
||||||
@@ -251,7 +259,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
||||||
|
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
logger.info("***** Eval results {} *****".format(prefix))
|
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
|
||||||
writer.write("model =%s\n" % str(args.model_name_or_path))
|
writer.write("model =%s\n" % str(args.model_name_or_path))
|
||||||
writer.write("total batch size=%d\n" % (args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
|
writer.write("total batch size=%d\n" % (args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
|
||||||
(torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
|
(torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
|
||||||
@@ -264,14 +272,21 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
processor = processors[task]()
|
processor = processors[task]()
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
|
if evaluate:
|
||||||
|
cached_mode = 'dev'
|
||||||
|
elif test:
|
||||||
|
cached_mode = 'test'
|
||||||
|
else:
|
||||||
|
cached_mode = 'train'
|
||||||
|
assert (evaluate == True and test == True) == False
|
||||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||||
'dev' if evaluate else 'train',
|
cached_mode,
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task)))
|
str(task)))
|
||||||
@@ -281,7 +296,12 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
if evaluate:
|
||||||
|
examples = processor.get_dev_examples(args.data_dir)
|
||||||
|
elif test:
|
||||||
|
examples = processor.get_test_examples(args.data_dir)
|
||||||
|
else:
|
||||||
|
examples = processor.get_train_examples(args.data_dir)
|
||||||
logger.info("Training number: %s", str(len(examples)))
|
logger.info("Training number: %s", str(len(examples)))
|
||||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
|
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
|
||||||
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
||||||
@@ -337,6 +357,7 @@ def main():
|
|||||||
help="Whether to run training.")
|
help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action='store_true',
|
||||||
help="Whether to run eval on the dev set.")
|
help="Whether to run eval on the dev set.")
|
||||||
|
parser.add_argument("--do_test", action='store_true', help='Whether to run test on the test set')
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||||
help="Rul evaluation during training at each logging step.")
|
help="Rul evaluation during training at each logging step.")
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument("--do_lower_case", action='store_true',
|
||||||
@@ -494,6 +515,22 @@ def main():
|
|||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
|
if args.do_test and args.local_rank in [-1, 0]:
|
||||||
|
if not args.do_train:
|
||||||
|
args.output_dir = args.model_name_or_path
|
||||||
|
checkpoints = [args.output_dir]
|
||||||
|
if args.eval_all_checkpoints:
|
||||||
|
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||||
|
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
|
for checkpoint in checkpoints:
|
||||||
|
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||||
|
model = model_class.from_pretrained(checkpoint)
|
||||||
|
model.to(args.device)
|
||||||
|
result = evaluate(args, model, tokenizer, prefix=global_step, test=True)
|
||||||
|
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||||
|
results.update(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -83,6 +83,10 @@ class DataProcessor(object):
|
|||||||
"""Gets a collection of `InputExample`s for the dev set."""
|
"""Gets a collection of `InputExample`s for the dev set."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""Gets a collection of `InputExample`s for the dev set."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""Gets the list of labels for this data set."""
|
"""Gets the list of labels for this data set."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -109,6 +113,15 @@ class RaceProcessor(DataProcessor):
|
|||||||
middle = self._read_txt(middle)
|
middle = self._read_txt(middle)
|
||||||
return self._create_examples(high + middle, 'dev')
|
return self._create_examples(high + middle, 'dev')
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
logger.info("LOOKING AT {} test".format(data_dir))
|
||||||
|
high = os.path.join(data_dir, 'test/high')
|
||||||
|
middle = os.path.join(data_dir, 'test/middle')
|
||||||
|
high = self._read_txt(high)
|
||||||
|
middle = self._read_txt(middle)
|
||||||
|
return self._create_examples(high + middle, 'test')
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1", "2", "3"]
|
return ["0", "1", "2", "3"]
|
||||||
@@ -157,6 +170,11 @@ class SwagProcessor(DataProcessor):
|
|||||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
|
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
logger.info("LOOKING AT {} test".format(data_dir))
|
||||||
|
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1", "2", "3"]
|
return ["0", "1", "2", "3"]
|
||||||
@@ -207,6 +225,10 @@ class ArcProcessor(DataProcessor):
|
|||||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||||
return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev")
|
return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
logger.info("LOOKING AT {} test".format(data_dir))
|
||||||
|
return self._create_examples(self._read_json(os.path.join(data_dir, "test.jsonl")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1", "2", "3"]
|
return ["0", "1", "2", "3"]
|
||||||
|
|||||||
Reference in New Issue
Block a user