NER support for Albert in run_ner.py and NerPipeline (#2983)
* * Added support for Albert when fine-tuning for NER * Added support for Albert in NER pipeline * Added command-line options to examples/ner/run_ner.py to better control tokenization * Added class AlbertForTokenClassification * Changed output for NerPipeline to use .convert_ids_to_tokens(...) instead of .decode(...) to better reflect tokens * Added , * Now passes style guide enforcement * Changes from reviews. * Code now passes style enforcement * Added test for AlbertForTokenClassification * Added test for AlbertForTokenClassification
This commit is contained in:
@@ -33,6 +33,9 @@ from tqdm import tqdm, trange
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertForTokenClassification,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertForTokenClassification,
|
||||
BertTokenizer,
|
||||
@@ -70,6 +73,7 @@ ALL_MODELS = sum(
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
||||
@@ -77,6 +81,8 @@ MODEL_CLASSES = {
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
|
||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@@ -462,7 +468,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keep_accents", action="store_const", const=True, help="Set this flag if model is trained with accents."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strip_accents", action="store_const", const=True, help="Set this flag if model is trained without accents."
|
||||
)
|
||||
parser.add_argument("--use_fast", action="store_const", const=True, help="Set this flag to use fast tokenization.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
@@ -590,10 +602,12 @@ def main():
|
||||
label2id={label: i for i, label in enumerate(labels)},
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
|
||||
logger.info("Tokenizer arguments: %s", tokenizer_args)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
**tokenizer_args,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
@@ -636,7 +650,7 @@ def main():
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
@@ -658,7 +672,7 @@ def main():
|
||||
writer.write("{} = {}\n".format(key, str(results[key])))
|
||||
|
||||
if args.do_predict and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||
|
||||
Reference in New Issue
Block a user