unified tokenizer api and serialization + tests
This commit is contained in:
@@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_transformers.modeling_bert import BertForSequenceClassification
|
||||
from pytorch_transformers.tokenization_bert import BertTokenizer
|
||||
from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
|
||||
XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
|
||||
XLMTokenizer)
|
||||
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
|
||||
|
||||
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
|
||||
@@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, c
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': BertForSequenceClassification,
|
||||
'xlnet': XLNetForSequenceClassification,
|
||||
'xlm': XLMForSequenceClassification,
|
||||
}
|
||||
|
||||
TOKENIZER_CLASSES = {
|
||||
'bert': BertTokenizer,
|
||||
'xlnet': XLNetTokenizer,
|
||||
'xlm': XLMTokenizer,
|
||||
}
|
||||
|
||||
def train(args, train_features, model):
|
||||
""" Train the model """
|
||||
@@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Num examples = %d", len(eval_features))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
@@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
|
||||
examples = processor.get_dev_examples(args.data_dir)
|
||||
cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
|
||||
'dev' if eval else 'train',
|
||||
list(filter(None, args.bert_model.split('/'))).pop(),
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
|
||||
@@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
|
||||
cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
@@ -230,12 +252,10 @@ def main():
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
|
||||
"bert-base-multilingual-cased, bert-base-chinese.")
|
||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train.")
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
@@ -243,9 +263,8 @@ def main():
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||
"than this will be padded.")
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
@@ -263,8 +282,7 @@ def main():
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. "
|
||||
"E.g., 0.1 = 10%% of training.")
|
||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
@@ -331,8 +349,11 @@ def main():
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||
args.model_type = args.model_name.lower().split('-')[0]
|
||||
args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
|
||||
args.model_class = MODEL_CLASSES[args.model_type]
|
||||
tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
@@ -359,27 +380,16 @@ def main():
|
||||
# Saving best practices: if you use the default names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Save a trained model, configuration and tokenizer
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForSequenceClassification.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
|
||||
torch.save(args, output_args_file)
|
||||
else:
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model)
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
model.to(args.device)
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = args.model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
|
||||
@@ -211,8 +211,8 @@ def main():
|
||||
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
|
||||
cls_token_at_end=True, cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
@@ -369,8 +369,8 @@ def main():
|
||||
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
|
||||
cls_token_at_end=True, cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
||||
|
||||
@@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
mask_padding_with_zero=True):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` defines the location of the CLS token:
|
||||
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
|
||||
`cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)
|
||||
"""
|
||||
@@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
[str(x) for x in tokens]))
|
||||
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info(
|
||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
logger.info("label: %s (id = %d)" % (example.label, label_id))
|
||||
|
||||
features.append(
|
||||
|
||||
Reference in New Issue
Block a user