Merge branch 'refs/heads/squad_roberta'

# Conflicts:
#	transformers/data/processors/squad.py
This commit is contained in:
erenup
2019-12-14 08:53:59 +08:00
6 changed files with 424 additions and 185 deletions

View File

@@ -39,6 +39,7 @@ from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer,
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering,
@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
'start_positions': batch[3],
'end_positions': batch[4]
'end_positions': batch[4],
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad():
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1]
'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM doesn't use segment_ids
example_indices = batch[3]
# XLNet and XLM use more arguments for their predictions
@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)
@@ -363,7 +360,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
return_dataset='pt'
return_dataset='pt',
threads=args.threads,
)
if args.local_rank in [-1, 0]:
@@ -481,6 +479,8 @@ def main():
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--threads', type=int, default=1, help='multiple threads for converting example to features')
args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: