[examples] Use AutoModels in more examples
This commit is contained in:
@@ -31,6 +31,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
@@ -38,7 +39,6 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
|
||||
|
||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||
|
||||
@@ -13,16 +13,11 @@ from seqeval import metrics
|
||||
|
||||
from transformers import (
|
||||
TF2_WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertTokenizer,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
GradientAccumulator,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
TFBertForTokenClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFRobertaForTokenClassification,
|
||||
TFAutoModelForTokenClassification,
|
||||
create_optimizer,
|
||||
)
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
@@ -34,22 +29,17 @@ except ImportError:
|
||||
from fastprogress.fastprogress import master_bar, progress_bar
|
||||
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
|
||||
)
|
||||
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
|
||||
}
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||
|
||||
|
||||
flags.DEFINE_string(
|
||||
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
|
||||
)
|
||||
|
||||
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_TYPES))
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_name_or_path",
|
||||
@@ -509,8 +499,7 @@ def main(_):
|
||||
labels = get_labels(args["labels"])
|
||||
num_labels = len(labels) + 1
|
||||
pad_token_label_id = 0
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
|
||||
config = config_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
||||
num_labels=num_labels,
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
@@ -520,14 +509,14 @@ def main(_):
|
||||
|
||||
# Training
|
||||
if args["do_train"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
||||
do_lower_case=args["do_lower_case"],
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
)
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(
|
||||
args["model_name_or_path"],
|
||||
from_pt=bool(".bin" in args["model_name_or_path"]),
|
||||
config=config,
|
||||
@@ -562,7 +551,7 @@ def main(_):
|
||||
|
||||
# Evaluation
|
||||
if args["do_eval"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
checkpoints = []
|
||||
results = []
|
||||
|
||||
@@ -584,7 +573,7 @@ def main(_):
|
||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)
|
||||
|
||||
y_true, y_pred, eval_loss = evaluate(
|
||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||
@@ -611,8 +600,8 @@ def main(_):
|
||||
writer.write("\n")
|
||||
|
||||
if args["do_predict"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
model = model_class.from_pretrained(args["output_dir"])
|
||||
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
|
||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||
predict_dataset, _ = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
||||
|
||||
@@ -30,32 +30,12 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
FlaubertConfig,
|
||||
FlaubertForSequenceClassification,
|
||||
FlaubertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMForSequenceClassification,
|
||||
XLMRobertaConfig,
|
||||
XLMRobertaForSequenceClassification,
|
||||
XLMRobertaTokenizer,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
@@ -72,33 +52,10 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (
|
||||
BertConfig,
|
||||
XLNetConfig,
|
||||
XLMConfig,
|
||||
RobertaConfig,
|
||||
DistilBertConfig,
|
||||
AlbertConfig,
|
||||
XLMRobertaConfig,
|
||||
FlaubertConfig,
|
||||
)
|
||||
),
|
||||
(),
|
||||
)
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
||||
"flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
|
||||
}
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
@@ -442,7 +399,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
@@ -622,19 +579,18 @@ def main():
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
@@ -673,14 +629,14 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
@@ -692,7 +648,7 @@ def main():
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
|
||||
@@ -38,28 +38,15 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertTokenizer,
|
||||
CamembertConfig,
|
||||
CamembertForMaskedLM,
|
||||
CamembertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertTokenizer,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
@@ -73,14 +60,8 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
|
||||
}
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
@@ -693,23 +674,21 @@ def main():
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
|
||||
if args.config_name:
|
||||
config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
||||
elif args.model_name_or_path:
|
||||
config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
else:
|
||||
config = config_class()
|
||||
config = CONFIG_MAPPING[args.model_type]()
|
||||
|
||||
if args.tokenizer_name:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
||||
elif args.model_name_or_path:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
|
||||
"and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
|
||||
"and load it from here, using --tokenizer_name".format(AutoTokenizer.__name__)
|
||||
)
|
||||
|
||||
if args.block_size <= 0:
|
||||
@@ -719,7 +698,7 @@ def main():
|
||||
args.block_size = min(args.block_size, tokenizer.max_len)
|
||||
|
||||
if args.model_name_or_path:
|
||||
model = model_class.from_pretrained(
|
||||
model = AutoModelWithLMHead.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
@@ -727,7 +706,7 @@ def main():
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
model = model_class(config=config)
|
||||
# Auto* classes are factories and raise if constructed directly; build the untrained model from its config.
model = AutoModelWithLMHead.from_config(config)
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
@@ -768,8 +747,8 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model = AutoModelWithLMHead.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
@@ -786,7 +765,7 @@ def main():
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = AutoModelWithLMHead.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
|
||||
@@ -30,29 +30,12 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertForQuestionAnswering,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
CamembertConfig,
|
||||
CamembertForQuestionAnswering,
|
||||
CamembertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForQuestionAnswering,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
squad_convert_examples_to_features,
|
||||
)
|
||||
@@ -72,23 +55,10 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
||||
}
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
@@ -513,7 +483,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
@@ -757,17 +727,16 @@ def main():
|
||||
torch.distributed.barrier()
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
@@ -817,8 +786,8 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir) # , force_download=True)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
@@ -842,7 +811,7 @@ def main():
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint) # , force_download=True)
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint) # , force_download=True)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
|
||||
Reference in New Issue
Block a user