Implement fine-tuning BERT on CoNLL-2003 named entity recognition task
This commit is contained in:
committed by
thomwolf
parent
5adb39e757
commit
383ef96747
@@ -55,7 +55,7 @@ def set_seed(args):
|
|||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
def train(args, train_dataset, model, tokenizer, pad_token_label_id):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
tb_writer = SummaryWriter()
|
tb_writer = SummaryWriter()
|
||||||
@@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id)
|
results = evaluate(args, model, tokenizer, pad_token_label_id)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
@@ -160,7 +160,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
model_to_save = model.module if hasattr(model,
|
||||||
|
"module") else model # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
@@ -178,8 +179,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""):
|
def evaluate(args, model, tokenizer, pad_token_label_id, prefix=""):
|
||||||
eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)
|
eval_dataset = load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=True)
|
||||||
|
|
||||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||||
# Note that DistributedSampler samples randomly
|
# Note that DistributedSampler samples randomly
|
||||||
@@ -219,7 +220,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
|||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
preds = np.argmax(preds, axis=2)
|
preds = np.argmax(preds, axis=2)
|
||||||
|
|
||||||
label_map = {i: label for i, label in enumerate(labels)}
|
label_map = {i: label for i, label in enumerate(get_labels())}
|
||||||
|
|
||||||
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
|
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
|
||||||
preds_list = [[] for _ in range(out_label_ids.shape[0])]
|
preds_list = [[] for _ in range(out_label_ids.shape[0])]
|
||||||
@@ -241,15 +242,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
|||||||
for key in sorted(results.keys()):
|
for key in sorted(results.keys()):
|
||||||
logger.info(" %s = %s", key, str(results[key]))
|
logger.info(" %s = %s", key, str(results[key]))
|
||||||
|
|
||||||
return results, preds_list
|
return results
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
|
def load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False):
|
||||||
if args.local_rank not in [-1, 0] and not evaluate:
|
if args.local_rank not in [-1, 0] and not evaluate:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode,
|
cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format("dev" if evaluate else "train",
|
||||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
str(args.max_seq_length)))
|
str(args.max_seq_length)))
|
||||||
if os.path.exists(cached_features_file):
|
if os.path.exists(cached_features_file):
|
||||||
@@ -257,8 +258,9 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
|
|||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||||
examples = read_examples_from_file(args.data_dir, mode)
|
label_list = get_labels()
|
||||||
features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer,
|
examples = read_examples_from_file(args.data_dir, evaluate=evaluate)
|
||||||
|
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
|
||||||
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
||||||
# xlnet has a cls token at the end
|
# xlnet has a cls token at the end
|
||||||
cls_token=tokenizer.cls_token,
|
cls_token=tokenizer.cls_token,
|
||||||
@@ -303,8 +305,6 @@ def main():
|
|||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
help="The output directory where the model predictions and checkpoints will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--labels", default="", type=str,
|
|
||||||
help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
|
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument("--config_name", default="", type=str,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
help="Pretrained config name or path if not the same as model_name")
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||||
@@ -318,8 +318,6 @@ def main():
|
|||||||
help="Whether to run training.")
|
help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true",
|
parser.add_argument("--do_eval", action="store_true",
|
||||||
help="Whether to run eval on the dev set.")
|
help="Whether to run eval on the dev set.")
|
||||||
parser.add_argument("--do_predict", action="store_true",
|
|
||||||
help="Whether to run predictions on the test set.")
|
|
||||||
parser.add_argument("--evaluate_during_training", action="store_true",
|
parser.add_argument("--evaluate_during_training", action="store_true",
|
||||||
help="Whether to run evaluation during training at each logging step.")
|
help="Whether to run evaluation during training at each logging step.")
|
||||||
parser.add_argument("--do_lower_case", action="store_true",
|
parser.add_argument("--do_lower_case", action="store_true",
|
||||||
@@ -408,8 +406,8 @@ def main():
|
|||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
# Prepare CONLL-2003 task
|
# Prepare CONLL-2003 task
|
||||||
labels = get_labels(args.labels)
|
label_list = get_labels()
|
||||||
num_labels = len(labels)
|
num_labels = len(label_list)
|
||||||
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
|
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
|
||||||
pad_token_label_id = CrossEntropyLoss().ignore_index
|
pad_token_label_id = CrossEntropyLoss().ignore_index
|
||||||
|
|
||||||
@@ -435,8 +433,8 @@ def main():
|
|||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
|
train_dataset = load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False)
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer, pad_token_label_id)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
@@ -468,7 +466,7 @@ def main():
|
|||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
|
result = evaluate(args, model, tokenizer, pad_token_label_id, prefix=global_step)
|
||||||
if global_step:
|
if global_step:
|
||||||
result = {"{}_{}".format(global_step, k): v for k, v in result.items()}
|
result = {"{}_{}".format(global_step, k): v for k, v in result.items()}
|
||||||
results.update(result)
|
results.update(result)
|
||||||
@@ -477,32 +475,6 @@ def main():
|
|||||||
for key in sorted(results.keys()):
|
for key in sorted(results.keys()):
|
||||||
writer.write("{} = {}\n".format(key, str(results[key])))
|
writer.write("{} = {}\n".format(key, str(results[key])))
|
||||||
|
|
||||||
if args.do_predict and args.local_rank in [-1, 0]:
|
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
|
||||||
model.to(args.device)
|
|
||||||
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
|
|
||||||
# Save results
|
|
||||||
output_test_results_file = os.path.join(args.output_dir, "test_results.txt")
|
|
||||||
with open(output_test_results_file, "w") as writer:
|
|
||||||
for key in sorted(result.keys()):
|
|
||||||
writer.write("{} = {}\n".format(key, str(result[key])))
|
|
||||||
# Save predictions
|
|
||||||
output_test_predictions_file = os.path.join(args.output_dir, "test_predictions.txt")
|
|
||||||
with open(output_test_predictions_file, "w") as writer:
|
|
||||||
with open(os.path.join(args.data_dir, "test.txt"), "r") as f:
|
|
||||||
example_id = 0
|
|
||||||
for line in f:
|
|
||||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
|
||||||
writer.write(line)
|
|
||||||
if not predictions[example_id]:
|
|
||||||
example_id += 1
|
|
||||||
elif predictions[example_id]:
|
|
||||||
output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
|
|
||||||
writer.write(output_line)
|
|
||||||
else:
|
|
||||||
logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -51,8 +51,13 @@ class InputFeatures(object):
|
|||||||
self.label_ids = label_ids
|
self.label_ids = label_ids
|
||||||
|
|
||||||
|
|
||||||
def read_examples_from_file(data_dir, mode):
|
def read_examples_from_file(data_dir, evaluate=False):
|
||||||
file_path = os.path.join(data_dir, "{}.txt".format(mode))
|
if evaluate:
|
||||||
|
file_path = os.path.join(data_dir, "dev.txt")
|
||||||
|
guid_prefix = "dev"
|
||||||
|
else:
|
||||||
|
file_path = os.path.join(data_dir, "train.txt")
|
||||||
|
guid_prefix = "train"
|
||||||
guid_index = 1
|
guid_index = 1
|
||||||
examples = []
|
examples = []
|
||||||
with open(file_path, encoding="utf-8") as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
@@ -61,7 +66,7 @@ def read_examples_from_file(data_dir, mode):
|
|||||||
for line in f:
|
for line in f:
|
||||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||||
if words:
|
if words:
|
||||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
|
examples.append(InputExample(guid="{}-{}".format(guid_prefix, guid_index),
|
||||||
words=words,
|
words=words,
|
||||||
labels=labels))
|
labels=labels))
|
||||||
guid_index += 1
|
guid_index += 1
|
||||||
@@ -70,13 +75,9 @@ def read_examples_from_file(data_dir, mode):
|
|||||||
else:
|
else:
|
||||||
splits = line.split(" ")
|
splits = line.split(" ")
|
||||||
words.append(splits[0])
|
words.append(splits[0])
|
||||||
if len(splits) > 1:
|
labels.append(splits[-1][:-1])
|
||||||
labels.append(splits[-1].replace("\n", ""))
|
|
||||||
else:
|
|
||||||
# Examples could have no label for mode = "test"
|
|
||||||
labels.append("O")
|
|
||||||
if words:
|
if words:
|
||||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
|
examples.append(InputExample(guid="%s-%d".format(guid_prefix, guid_index),
|
||||||
words=words,
|
words=words,
|
||||||
labels=labels))
|
labels=labels))
|
||||||
return examples
|
return examples
|
||||||
@@ -201,12 +202,5 @@ def convert_examples_to_features(examples,
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def get_labels(path):
|
def get_labels():
|
||||||
if path:
|
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||||
with open(path, "r") as f:
|
|
||||||
labels = f.read().splitlines()
|
|
||||||
if "O" not in labels:
|
|
||||||
labels = ["O"] + labels
|
|
||||||
return labels
|
|
||||||
else:
|
|
||||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user