[examples] Use AutoModels in more examples
This commit is contained in:
@@ -31,6 +31,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
@@ -38,7 +39,6 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
|
||||
|
||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||
|
||||
@@ -13,16 +13,11 @@ from seqeval import metrics
|
||||
|
||||
from transformers import (
|
||||
TF2_WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertTokenizer,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
GradientAccumulator,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
TFBertForTokenClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFRobertaForTokenClassification,
|
||||
TFAutoModelForTokenClassification,
|
||||
create_optimizer,
|
||||
)
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
@@ -34,22 +29,17 @@ except ImportError:
|
||||
from fastprogress.fastprogress import master_bar, progress_bar
|
||||
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
|
||||
)
|
||||
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
|
||||
}
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||
|
||||
|
||||
flags.DEFINE_string(
|
||||
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
|
||||
)
|
||||
|
||||
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_TYPES))
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_name_or_path",
|
||||
@@ -509,8 +499,7 @@ def main(_):
|
||||
labels = get_labels(args["labels"])
|
||||
num_labels = len(labels) + 1
|
||||
pad_token_label_id = 0
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
|
||||
config = config_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
||||
num_labels=num_labels,
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
@@ -520,14 +509,14 @@ def main(_):
|
||||
|
||||
# Training
|
||||
if args["do_train"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
||||
do_lower_case=args["do_lower_case"],
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
)
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(
|
||||
args["model_name_or_path"],
|
||||
from_pt=bool(".bin" in args["model_name_or_path"]),
|
||||
config=config,
|
||||
@@ -562,7 +551,7 @@ def main(_):
|
||||
|
||||
# Evaluation
|
||||
if args["do_eval"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
checkpoints = []
|
||||
results = []
|
||||
|
||||
@@ -584,7 +573,7 @@ def main(_):
|
||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)
|
||||
|
||||
y_true, y_pred, eval_loss = evaluate(
|
||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||
@@ -611,8 +600,8 @@ def main(_):
|
||||
writer.write("\n")
|
||||
|
||||
if args["do_predict"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
model = model_class.from_pretrained(args["output_dir"])
|
||||
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
|
||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||
predict_dataset, _ = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
||||
|
||||
Reference in New Issue
Block a user