Update examples/ner/run_ner.py to use AutoModel (#3305)
* Update examples/ner/run_ner.py to use AutoModel * Fix missing code and apply `make style` command
This commit is contained in:
@@ -31,28 +31,15 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
AlbertConfig,
|
AutoConfig,
|
||||||
AlbertForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AlbertTokenizer,
|
AutoTokenizer,
|
||||||
BertConfig,
|
|
||||||
BertForTokenClassification,
|
|
||||||
BertTokenizer,
|
|
||||||
CamembertConfig,
|
|
||||||
CamembertForTokenClassification,
|
|
||||||
CamembertTokenizer,
|
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForTokenClassification,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
RobertaConfig,
|
|
||||||
RobertaForTokenClassification,
|
|
||||||
RobertaTokenizer,
|
|
||||||
XLMRobertaConfig,
|
|
||||||
XLMRobertaForTokenClassification,
|
|
||||||
XLMRobertaTokenizer,
|
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_auto import MODEL_MAPPING
|
||||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||||
|
|
||||||
|
|
||||||
@@ -64,22 +51,8 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
(
|
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
|
||||||
tuple(conf.pretrained_config_archive_map.keys())
|
|
||||||
for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
|
|
||||||
),
|
|
||||||
(),
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
|
||||||
"albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
|
|
||||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
|
||||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
|
||||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
|
||||||
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
|
|
||||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
|
|
||||||
}
|
|
||||||
|
|
||||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||||
|
|
||||||
@@ -411,7 +384,7 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
@@ -594,8 +567,7 @@ def main():
|
|||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config = AutoConfig.from_pretrained(
|
||||||
config = config_class.from_pretrained(
|
|
||||||
args.config_name if args.config_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
id2label={str(i): label for i, label in enumerate(labels)},
|
id2label={str(i): label for i, label in enumerate(labels)},
|
||||||
@@ -604,12 +576,12 @@ def main():
|
|||||||
)
|
)
|
||||||
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
|
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
|
||||||
logger.info("Tokenizer arguments: %s", tokenizer_args)
|
logger.info("Tokenizer arguments: %s", tokenizer_args)
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
**tokenizer_args,
|
**tokenizer_args,
|
||||||
)
|
)
|
||||||
model = model_class.from_pretrained(
|
model = AutoModelForTokenClassification.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
@@ -650,7 +622,7 @@ def main():
|
|||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(
|
checkpoints = list(
|
||||||
@@ -660,7 +632,7 @@ def main():
|
|||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = AutoModelForTokenClassification.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
|
result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
|
||||||
if global_step:
|
if global_step:
|
||||||
@@ -672,8 +644,8 @@ def main():
|
|||||||
writer.write("{} = {}\n".format(key, str(results[key])))
|
writer.write("{} = {}\n".format(key, str(results[key])))
|
||||||
|
|
||||||
if args.do_predict and args.local_rank in [-1, 0]:
|
if args.do_predict and args.local_rank in [-1, 0]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = AutoModelForTokenClassification.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
|
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||||
# Save results
|
# Save results
|
||||||
|
|||||||
Reference in New Issue
Block a user