[examples] Use AutoModels in more examples
This commit is contained in:
@@ -31,6 +31,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -38,7 +39,6 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
|
||||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||||
|
|
||||||
|
|
||||||
@@ -52,6 +52,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
|
||||||
|
|
||||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||||
|
|||||||
@@ -13,16 +13,11 @@ from seqeval import metrics
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
BertConfig,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertTokenizer,
|
AutoConfig,
|
||||||
DistilBertConfig,
|
AutoTokenizer,
|
||||||
DistilBertTokenizer,
|
|
||||||
GradientAccumulator,
|
GradientAccumulator,
|
||||||
RobertaConfig,
|
TFAutoModelForTokenClassification,
|
||||||
RobertaTokenizer,
|
|
||||||
TFBertForTokenClassification,
|
|
||||||
TFDistilBertForTokenClassification,
|
|
||||||
TFRobertaForTokenClassification,
|
|
||||||
create_optimizer,
|
create_optimizer,
|
||||||
)
|
)
|
||||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||||
@@ -34,22 +29,17 @@ except ImportError:
|
|||||||
from fastprogress.fastprogress import master_bar, progress_bar
|
from fastprogress.fastprogress import master_bar, progress_bar
|
||||||
|
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||||
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
|
||||||
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
|
||||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
|
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
|
||||||
)
|
)
|
||||||
|
|
||||||
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_TYPES))
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"model_name_or_path",
|
"model_name_or_path",
|
||||||
@@ -509,8 +499,7 @@ def main(_):
|
|||||||
labels = get_labels(args["labels"])
|
labels = get_labels(args["labels"])
|
||||||
num_labels = len(labels) + 1
|
num_labels = len(labels) + 1
|
||||||
pad_token_label_id = 0
|
pad_token_label_id = 0
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
|
config = AutoConfig.from_pretrained(
|
||||||
config = config_class.from_pretrained(
|
|
||||||
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||||
@@ -520,14 +509,14 @@ def main(_):
|
|||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args["do_train"]:
|
if args["do_train"]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
||||||
do_lower_case=args["do_lower_case"],
|
do_lower_case=args["do_lower_case"],
|
||||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
model = model_class.from_pretrained(
|
model = TFAutoModelForTokenClassification.from_pretrained(
|
||||||
args["model_name_or_path"],
|
args["model_name_or_path"],
|
||||||
from_pt=bool(".bin" in args["model_name_or_path"]),
|
from_pt=bool(".bin" in args["model_name_or_path"]),
|
||||||
config=config,
|
config=config,
|
||||||
@@ -562,7 +551,7 @@ def main(_):
|
|||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if args["do_eval"]:
|
if args["do_eval"]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||||
checkpoints = []
|
checkpoints = []
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -584,7 +573,7 @@ def main(_):
|
|||||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)
|
||||||
|
|
||||||
y_true, y_pred, eval_loss = evaluate(
|
y_true, y_pred, eval_loss = evaluate(
|
||||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||||
@@ -611,8 +600,8 @@ def main(_):
|
|||||||
writer.write("\n")
|
writer.write("\n")
|
||||||
|
|
||||||
if args["do_predict"]:
|
if args["do_predict"]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||||
model = model_class.from_pretrained(args["output_dir"])
|
model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
|
||||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||||
predict_dataset, _ = load_and_cache_examples(
|
predict_dataset, _ = load_and_cache_examples(
|
||||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
||||||
|
|||||||
@@ -30,32 +30,12 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
AlbertConfig,
|
AutoConfig,
|
||||||
AlbertForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AlbertTokenizer,
|
AutoTokenizer,
|
||||||
BertConfig,
|
|
||||||
BertForSequenceClassification,
|
|
||||||
BertTokenizer,
|
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForSequenceClassification,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
FlaubertConfig,
|
|
||||||
FlaubertForSequenceClassification,
|
|
||||||
FlaubertTokenizer,
|
|
||||||
RobertaConfig,
|
|
||||||
RobertaForSequenceClassification,
|
|
||||||
RobertaTokenizer,
|
|
||||||
XLMConfig,
|
|
||||||
XLMForSequenceClassification,
|
|
||||||
XLMRobertaConfig,
|
|
||||||
XLMRobertaForSequenceClassification,
|
|
||||||
XLMRobertaTokenizer,
|
|
||||||
XLMTokenizer,
|
|
||||||
XLNetConfig,
|
|
||||||
XLNetForSequenceClassification,
|
|
||||||
XLNetTokenizer,
|
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from transformers import glue_compute_metrics as compute_metrics
|
from transformers import glue_compute_metrics as compute_metrics
|
||||||
@@ -72,33 +52,10 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())
|
||||||
(
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
tuple(conf.pretrained_config_archive_map.keys())
|
|
||||||
for conf in (
|
|
||||||
BertConfig,
|
|
||||||
XLNetConfig,
|
|
||||||
XLMConfig,
|
|
||||||
RobertaConfig,
|
|
||||||
DistilBertConfig,
|
|
||||||
AlbertConfig,
|
|
||||||
XLMRobertaConfig,
|
|
||||||
FlaubertConfig,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
(),
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
|
||||||
"xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
|
||||||
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
|
||||||
"roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
|
||||||
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
|
||||||
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
|
||||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
|
||||||
"flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
@@ -442,7 +399,7 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
@@ -622,19 +579,18 @@ def main():
|
|||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config = AutoConfig.from_pretrained(
|
||||||
config = config_class.from_pretrained(
|
|
||||||
args.config_name if args.config_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
finetuning_task=args.task_name,
|
finetuning_task=args.task_name,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
)
|
)
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
)
|
)
|
||||||
model = model_class.from_pretrained(
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
@@ -673,14 +629,14 @@ def main():
|
|||||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(
|
checkpoints = list(
|
||||||
@@ -692,7 +648,7 @@ def main():
|
|||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
|
|||||||
@@ -38,28 +38,15 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
CONFIG_MAPPING,
|
||||||
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
BertConfig,
|
AutoConfig,
|
||||||
BertForMaskedLM,
|
AutoModelWithLMHead,
|
||||||
BertTokenizer,
|
AutoTokenizer,
|
||||||
CamembertConfig,
|
|
||||||
CamembertForMaskedLM,
|
|
||||||
CamembertTokenizer,
|
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForMaskedLM,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
GPT2Config,
|
|
||||||
GPT2LMHeadModel,
|
|
||||||
GPT2Tokenizer,
|
|
||||||
OpenAIGPTConfig,
|
|
||||||
OpenAIGPTLMHeadModel,
|
|
||||||
OpenAIGPTTokenizer,
|
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
RobertaConfig,
|
|
||||||
RobertaForMaskedLM,
|
|
||||||
RobertaTokenizer,
|
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,14 +60,8 @@ except ImportError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
|
||||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
|
||||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
|
||||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
|
||||||
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
@@ -693,23 +674,21 @@ def main():
|
|||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||||
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
|
||||||
|
|
||||||
if args.config_name:
|
if args.config_name:
|
||||||
config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
||||||
elif args.model_name_or_path:
|
elif args.model_name_or_path:
|
||||||
config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||||
else:
|
else:
|
||||||
config = config_class()
|
config = CONFIG_MAPPING[args.model_type]()
|
||||||
|
|
||||||
if args.tokenizer_name:
|
if args.tokenizer_name:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
||||||
elif args.model_name_or_path:
|
elif args.model_name_or_path:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
|
"You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
|
||||||
"and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
|
"and load it from here, using --tokenizer_name".format(AutoTokenizer.__name__)
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.block_size <= 0:
|
if args.block_size <= 0:
|
||||||
@@ -719,7 +698,7 @@ def main():
|
|||||||
args.block_size = min(args.block_size, tokenizer.max_len)
|
args.block_size = min(args.block_size, tokenizer.max_len)
|
||||||
|
|
||||||
if args.model_name_or_path:
|
if args.model_name_or_path:
|
||||||
model = model_class.from_pretrained(
|
model = AutoModelWithLMHead.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
@@ -727,7 +706,7 @@ def main():
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Training new model from scratch")
|
logger.info("Training new model from scratch")
|
||||||
model = model_class(config=config)
|
# Auto classes are factories and raise if constructed directly; from_config builds the architecture named by `config` with freshly initialized weights.
model = AutoModelWithLMHead.from_config(config)
|
||||||
|
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
@@ -768,8 +747,8 @@ def main():
|
|||||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = AutoModelWithLMHead.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
@@ -786,7 +765,7 @@ def main():
|
|||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = AutoModelWithLMHead.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
|
|||||||
@@ -30,29 +30,12 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
AlbertConfig,
|
AutoConfig,
|
||||||
AlbertForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AlbertTokenizer,
|
AutoTokenizer,
|
||||||
BertConfig,
|
|
||||||
BertForQuestionAnswering,
|
|
||||||
BertTokenizer,
|
|
||||||
CamembertConfig,
|
|
||||||
CamembertForQuestionAnswering,
|
|
||||||
CamembertTokenizer,
|
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForQuestionAnswering,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
RobertaConfig,
|
|
||||||
RobertaForQuestionAnswering,
|
|
||||||
RobertaTokenizer,
|
|
||||||
XLMConfig,
|
|
||||||
XLMForQuestionAnswering,
|
|
||||||
XLMTokenizer,
|
|
||||||
XLNetConfig,
|
|
||||||
XLNetForQuestionAnswering,
|
|
||||||
XLNetTokenizer,
|
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
squad_convert_examples_to_features,
|
squad_convert_examples_to_features,
|
||||||
)
|
)
|
||||||
@@ -72,23 +55,10 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
||||||
(
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
tuple(conf.pretrained_config_archive_map.keys())
|
|
||||||
for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig)
|
|
||||||
),
|
|
||||||
(),
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
|
||||||
"camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer),
|
|
||||||
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
|
||||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
|
||||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
|
||||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
|
||||||
"albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
@@ -513,7 +483,7 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
@@ -757,17 +727,16 @@ def main():
|
|||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config = AutoConfig.from_pretrained(
|
||||||
config = config_class.from_pretrained(
|
|
||||||
args.config_name if args.config_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
)
|
)
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
)
|
)
|
||||||
model = model_class.from_pretrained(
|
model = AutoModelForQuestionAnswering.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
@@ -817,8 +786,8 @@ def main():
|
|||||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir) # , force_download=True)
|
model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
@@ -842,7 +811,7 @@ def main():
|
|||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint) # , force_download=True)
|
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint) # , force_download=True)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
|
|||||||
@@ -158,6 +158,12 @@ if is_torch_available():
|
|||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
MODEL_MAPPING,
|
||||||
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .modeling_bert import (
|
from .modeling_bert import (
|
||||||
@@ -317,6 +323,12 @@ if is_tf_available():
|
|||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
TF_MODEL_MAPPING,
|
||||||
|
TF_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .modeling_tf_bert import (
|
from .modeling_tf_bert import (
|
||||||
|
|||||||
@@ -28,20 +28,12 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
BertConfig,
|
AutoConfig,
|
||||||
BertForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
BertTokenizer,
|
AutoTokenizer,
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForQuestionAnswering,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
XLMConfig,
|
|
||||||
XLMForQuestionAnswering,
|
|
||||||
XLMTokenizer,
|
|
||||||
XLNetConfig,
|
|
||||||
XLNetForQuestionAnswering,
|
|
||||||
XLNetTokenizer,
|
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from utils_squad import (
|
from utils_squad import (
|
||||||
@@ -68,16 +60,10 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
||||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
|
||||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
|
||||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
|
||||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
@@ -418,7 +404,7 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
@@ -626,17 +612,16 @@ def main():
|
|||||||
# download model & vocab
|
# download model & vocab
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config = AutoConfig.from_pretrained(
|
||||||
config = config_class.from_pretrained(
|
|
||||||
args.config_name if args.config_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
)
|
)
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
)
|
)
|
||||||
model = model_class.from_pretrained(
|
model = AutoModelForQuestionAnswering.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
@@ -687,8 +672,8 @@ def main():
|
|||||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
@@ -706,7 +691,7 @@ def main():
|
|||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
|
|||||||
Reference in New Issue
Block a user