From a8e3336a850e856188350a93e67d77c07c85b8af Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 23 Mar 2020 19:30:19 -0400 Subject: [PATCH] [examples] Use AutoModels in more examples --- examples/ner/run_ner.py | 3 +- examples/ner/run_tf_ner.py | 41 ++++------ examples/run_glue.py | 74 ++++--------------- examples/run_language_modeling.py | 57 +++++--------- examples/run_squad.py | 59 ++++----------- src/transformers/__init__.py | 12 +++ .../adding_a_new_example_script/run_xxx.py | 43 ++++------- 7 files changed, 90 insertions(+), 199 deletions(-) diff --git a/examples/ner/run_ner.py b/examples/ner/run_ner.py index c2dfe6856d..818ff91136 100644 --- a/examples/ner/run_ner.py +++ b/examples/ner/run_ner.py @@ -31,6 +31,7 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from transformers import ( + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, WEIGHTS_NAME, AdamW, AutoConfig, @@ -38,7 +39,6 @@ from transformers import ( AutoTokenizer, get_linear_schedule_with_warmup, ) -from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file @@ -52,6 +52,7 @@ logger = logging.getLogger(__name__) MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ()) TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"] diff --git a/examples/ner/run_tf_ner.py b/examples/ner/run_tf_ner.py index ef970d8390..0a607ff662 100644 --- a/examples/ner/run_tf_ner.py +++ b/examples/ner/run_tf_ner.py @@ -13,16 +13,11 @@ from seqeval import metrics from transformers import ( TF2_WEIGHTS_NAME, - BertConfig, - BertTokenizer, - DistilBertConfig, - DistilBertTokenizer, + TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + AutoConfig, + AutoTokenizer, 
GradientAccumulator, - RobertaConfig, - RobertaTokenizer, - TFBertForTokenClassification, - TFDistilBertForTokenClassification, - TFRobertaForTokenClassification, + TFAutoModelForTokenClassification, create_optimizer, ) from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file @@ -34,22 +29,17 @@ except ImportError: from fastprogress.fastprogress import master_bar, progress_bar -ALL_MODELS = sum( - (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), () -) +MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) -MODEL_CLASSES = { - "bert": (BertConfig, TFBertForTokenClassification, BertTokenizer), - "roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer), - "distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer), -} +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),) flags.DEFINE_string( "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task." 
) -flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) +flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_TYPES)) flags.DEFINE_string( "model_name_or_path", @@ -509,8 +499,7 @@ def main(_): labels = get_labels(args["labels"]) num_labels = len(labels) + 1 pad_token_label_id = 0 - config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]] - config = config_class.from_pretrained( + config = AutoConfig.from_pretrained( args["config_name"] if args["config_name"] else args["model_name_or_path"], num_labels=num_labels, cache_dir=args["cache_dir"] if args["cache_dir"] else None, @@ -520,14 +509,14 @@ def main(_): # Training if args["do_train"]: - tokenizer = tokenizer_class.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"], do_lower_case=args["do_lower_case"], cache_dir=args["cache_dir"] if args["cache_dir"] else None, ) with strategy.scope(): - model = model_class.from_pretrained( + model = TFAutoModelForTokenClassification.from_pretrained( args["model_name_or_path"], from_pt=bool(".bin" in args["model_name_or_path"]), config=config, @@ -562,7 +551,7 @@ def main(_): # Evaluation if args["do_eval"]: - tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"]) + tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"]) checkpoints = [] results = [] @@ -584,7 +573,7 @@ def main(_): global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final" with strategy.scope(): - model = model_class.from_pretrained(checkpoint) + model = TFAutoModelForTokenClassification.from_pretrained(checkpoint) y_true, y_pred, eval_loss = evaluate( args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev" @@ -611,8 +600,8 @@ def main(_): writer.write("\n") 
if args["do_predict"]: - tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"]) - model = model_class.from_pretrained(args["output_dir"]) + tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"]) + model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"]) eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"] predict_dataset, _ = load_and_cache_examples( args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test" diff --git a/examples/run_glue.py b/examples/run_glue.py index f8b17978eb..72fdc2b497 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -30,32 +30,12 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from transformers import ( + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, WEIGHTS_NAME, AdamW, - AlbertConfig, - AlbertForSequenceClassification, - AlbertTokenizer, - BertConfig, - BertForSequenceClassification, - BertTokenizer, - DistilBertConfig, - DistilBertForSequenceClassification, - DistilBertTokenizer, - FlaubertConfig, - FlaubertForSequenceClassification, - FlaubertTokenizer, - RobertaConfig, - RobertaForSequenceClassification, - RobertaTokenizer, - XLMConfig, - XLMForSequenceClassification, - XLMRobertaConfig, - XLMRobertaForSequenceClassification, - XLMRobertaTokenizer, - XLMTokenizer, - XLNetConfig, - XLNetForSequenceClassification, - XLNetTokenizer, + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, get_linear_schedule_with_warmup, ) from transformers import glue_compute_metrics as compute_metrics @@ -72,33 +52,10 @@ except ImportError: logger = logging.getLogger(__name__) -ALL_MODELS = sum( - ( - tuple(conf.pretrained_config_archive_map.keys()) - for conf in ( - BertConfig, - XLNetConfig, - XLMConfig, - RobertaConfig, - DistilBertConfig, - AlbertConfig, - XLMRobertaConfig, - FlaubertConfig, - ) - ), - (), -) +MODEL_CONFIG_CLASSES = 
list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) -MODEL_CLASSES = { - "bert": (BertConfig, BertForSequenceClassification, BertTokenizer), - "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), - "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer), - "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), - "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), - "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer), - "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer), - "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer), -} +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),) def set_seed(args): @@ -442,7 +399,7 @@ def main(): default=None, type=str, required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + help="Model type selected in the list: " + ", ".join(MODEL_TYPES), ) parser.add_argument( "--model_name_or_path", @@ -622,19 +579,18 @@ def main(): torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab args.model_type = args.model_type.lower() - config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained( + config = AutoConfig.from_pretrained( args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name, cache_dir=args.cache_dir if args.cache_dir else None, ) - tokenizer = tokenizer_class.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None, ) - model = 
model_class.from_pretrained( + model = AutoModelForSequenceClassification.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, @@ -673,14 +629,14 @@ def main(): torch.save(args, os.path.join(args.output_dir, "training_args.bin")) # Load a trained model and vocabulary that you have fine-tuned - model = model_class.from_pretrained(args.output_dir) - tokenizer = tokenizer_class.from_pretrained(args.output_dir) + model = AutoModelForSequenceClassification.from_pretrained(args.output_dir) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir) model.to(args.device) # Evaluation results = {} if args.do_eval and args.local_rank in [-1, 0]: - tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) checkpoints = [args.output_dir] if args.eval_all_checkpoints: checkpoints = list( @@ -692,7 +648,7 @@ def main(): global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" - model = model_class.from_pretrained(checkpoint) + model = AutoModelForSequenceClassification.from_pretrained(checkpoint) model.to(args.device) result = evaluate(args, model, tokenizer, prefix=prefix) result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) diff --git a/examples/run_language_modeling.py b/examples/run_language_modeling.py index c66cc8978f..fa462c65ca 100644 --- a/examples/run_language_modeling.py +++ b/examples/run_language_modeling.py @@ -38,28 +38,15 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from transformers import ( + CONFIG_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, WEIGHTS_NAME, AdamW, - BertConfig, - BertForMaskedLM, - BertTokenizer, - CamembertConfig, - CamembertForMaskedLM, - CamembertTokenizer, - DistilBertConfig, - 
DistilBertForMaskedLM, - DistilBertTokenizer, - GPT2Config, - GPT2LMHeadModel, - GPT2Tokenizer, - OpenAIGPTConfig, - OpenAIGPTLMHeadModel, - OpenAIGPTTokenizer, + AutoConfig, + AutoModelWithLMHead, + AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, - RobertaConfig, - RobertaForMaskedLM, - RobertaTokenizer, get_linear_schedule_with_warmup, ) @@ -73,14 +60,8 @@ except ImportError: logger = logging.getLogger(__name__) -MODEL_CLASSES = { - "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), - "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), - "bert": (BertConfig, BertForMaskedLM, BertTokenizer), - "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), - "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), - "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer), -} +MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) class TextDataset(Dataset): @@ -693,23 +674,21 @@ def main(): if args.local_rank not in [-1, 0]: torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab - config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - if args.config_name: - config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir) + config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir) elif args.model_name_or_path: - config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) else: - config = config_class() + config = CONFIG_MAPPING[args.model_type]() if args.tokenizer_name: - tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir) elif 
args.model_name_or_path: - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) else: raise ValueError( "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it," - "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__) + "and load it from here, using --tokenizer_name".format(AutoTokenizer.__name__) ) if args.block_size <= 0: @@ -719,7 +698,7 @@ def main(): args.block_size = min(args.block_size, tokenizer.max_len) if args.model_name_or_path: - model = model_class.from_pretrained( + model = AutoModelWithLMHead.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, @@ -727,7 +706,7 @@ def main(): ) else: logger.info("Training new model from scratch") - model = model_class(config=config) + model = AutoModelWithLMHead.from_config(config) model.to(args.device) @@ -768,8 +747,8 @@ def main(): torch.save(args, os.path.join(args.output_dir, "training_args.bin")) # Load a trained model and vocabulary that you have fine-tuned - model = model_class.from_pretrained(args.output_dir) - tokenizer = tokenizer_class.from_pretrained(args.output_dir) + model = AutoModelWithLMHead.from_pretrained(args.output_dir) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir) model.to(args.device) # Evaluation @@ -786,7 +765,7 @@ def main(): global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" - model = model_class.from_pretrained(checkpoint) + model = AutoModelWithLMHead.from_pretrained(checkpoint) model.to(args.device) result = evaluate(args, model, tokenizer, prefix=prefix) result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) diff --git a/examples/run_squad.py b/examples/run_squad.py index 
523093e1bb..404ab72311 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -30,29 +30,12 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from transformers import ( + MODEL_FOR_QUESTION_ANSWERING_MAPPING, WEIGHTS_NAME, AdamW, - AlbertConfig, - AlbertForQuestionAnswering, - AlbertTokenizer, - BertConfig, - BertForQuestionAnswering, - BertTokenizer, - CamembertConfig, - CamembertForQuestionAnswering, - CamembertTokenizer, - DistilBertConfig, - DistilBertForQuestionAnswering, - DistilBertTokenizer, - RobertaConfig, - RobertaForQuestionAnswering, - RobertaTokenizer, - XLMConfig, - XLMForQuestionAnswering, - XLMTokenizer, - XLNetConfig, - XLNetForQuestionAnswering, - XLNetTokenizer, + AutoConfig, + AutoModelForQuestionAnswering, + AutoTokenizer, get_linear_schedule_with_warmup, squad_convert_examples_to_features, ) @@ -72,23 +55,10 @@ except ImportError: logger = logging.getLogger(__name__) -ALL_MODELS = sum( - ( - tuple(conf.pretrained_config_archive_map.keys()) - for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig) - ), - (), -) +MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) -MODEL_CLASSES = { - "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer), - "camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer), - "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer), - "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), - "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), - "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), - "albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer), -} +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),) def set_seed(args): @@ -513,7 +483,7 @@ def main(): default=None, type=str, 
required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + help="Model type selected in the list: " + ", ".join(MODEL_TYPES), ) parser.add_argument( "--model_name_or_path", @@ -757,17 +727,16 @@ def main(): torch.distributed.barrier() args.model_type = args.model_type.lower() - config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained( + config = AutoConfig.from_pretrained( args.config_name if args.config_name else args.model_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None, ) - tokenizer = tokenizer_class.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None, ) - model = model_class.from_pretrained( + model = AutoModelForQuestionAnswering.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, @@ -817,8 +786,8 @@ def main(): torch.save(args, os.path.join(args.output_dir, "training_args.bin")) # Load a trained model and vocabulary that you have fine-tuned - model = model_class.from_pretrained(args.output_dir) # , force_download=True) - tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) model.to(args.device) # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory @@ -842,7 +811,7 @@ def main(): for checkpoint in checkpoints: # Reload the model global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" - model = model_class.from_pretrained(checkpoint) # , force_download=True) + model = AutoModelForQuestionAnswering.from_pretrained(checkpoint) # , 
force_download=True) model.to(args.device) # Evaluate diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8ea0354dab..fb45cb2a8e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -158,6 +158,12 @@ if is_torch_available(): AutoModelWithLMHead, AutoModelForTokenClassification, ALL_PRETRAINED_MODEL_ARCHIVE_MAP, + MODEL_MAPPING, + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, ) from .modeling_bert import ( @@ -317,6 +323,12 @@ if is_tf_available(): TFAutoModelWithLMHead, TFAutoModelForTokenClassification, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP, + TF_MODEL_MAPPING, + TF_MODEL_FOR_PRETRAINING_MAPPING, + TF_MODEL_WITH_LM_HEAD_MAPPING, + TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, ) from .modeling_tf_bert import ( diff --git a/templates/adding_a_new_example_script/run_xxx.py b/templates/adding_a_new_example_script/run_xxx.py index 20f4b7360b..a4047c865a 100644 --- a/templates/adding_a_new_example_script/run_xxx.py +++ b/templates/adding_a_new_example_script/run_xxx.py @@ -28,20 +28,12 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from transformers import ( + MODEL_FOR_QUESTION_ANSWERING_MAPPING, WEIGHTS_NAME, AdamW, - BertConfig, - BertForQuestionAnswering, - BertTokenizer, - DistilBertConfig, - DistilBertForQuestionAnswering, - DistilBertTokenizer, - XLMConfig, - XLMForQuestionAnswering, - XLMTokenizer, - XLNetConfig, - XLNetForQuestionAnswering, - XLNetTokenizer, + AutoConfig, + AutoModelForQuestionAnswering, + AutoTokenizer, get_linear_schedule_with_warmup, ) from utils_squad import ( @@ -68,16 +60,10 @@ except ImportError: logger = logging.getLogger(__name__) -ALL_MODELS = sum( - (tuple(conf.pretrained_config_archive_map.keys()) for conf in 
(BertConfig, XLNetConfig, XLMConfig)), () -) +MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) -MODEL_CLASSES = { - "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer), - "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), - "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), - "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), -} +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),) def set_seed(args): @@ -418,7 +404,7 @@ def main(): default=None, type=str, required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + help="Model type selected in the list: " + ", ".join(MODEL_TYPES), ) parser.add_argument( "--model_name_or_path", @@ -626,17 +612,16 @@ def main(): # download model & vocab args.model_type = args.model_type.lower() - config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained( + config = AutoConfig.from_pretrained( args.config_name if args.config_name else args.model_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None, ) - tokenizer = tokenizer_class.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None, ) - model = model_class.from_pretrained( + model = AutoModelForQuestionAnswering.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, @@ -687,8 +672,8 @@ def main(): torch.save(args, os.path.join(args.output_dir, "training_args.bin")) # Load a trained model and vocabulary that you have fine-tuned - model = model_class.from_pretrained(args.output_dir) - tokenizer = tokenizer_class.from_pretrained(args.output_dir, 
do_lower_case=args.do_lower_case) + model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) model.to(args.device) # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory @@ -706,7 +691,7 @@ def main(): for checkpoint in checkpoints: # Reload the model global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" - model = model_class.from_pretrained(checkpoint) + model = AutoModelForQuestionAnswering.from_pretrained(checkpoint) model.to(args.device) # Evaluate