Works for XLNet

This commit is contained in:
Lysandre
2019-11-22 14:36:49 -05:00
committed by LysandreJik
parent a5a8a6175f
commit c3ba645237
2 changed files with 50 additions and 72 deletions

View File

@@ -16,6 +16,7 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor
import argparse
import logging
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
from utils_squad import (read_squad_examples, convert_examples_to_features,
RawResult, write_predictions,
from utils_squad import (RawResult, write_predictions,
RawResultExtended, write_predictions_extended)
# The following import is the official SQuAD evaluation script (2.0).
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
results = evaluate_on_squad(evaluate_options)
return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
@@ -308,9 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
examples = read_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
examples = examples[:10]
features = convert_examples_to_features(examples=examples,
keep_n_examples = 1000
processor = SquadV1Processor()
values = processor.get_dev_examples("examples/squad")
examples = values[:keep_n_examples]
features = squad_convert_examples_to_features(examples=exampless,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
@@ -320,29 +321,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
exampless = sread_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
exampless = exampless[:10]
features2 = squad_convert_examples_to_features(examples=exampless,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
print(features2)
for i in range(len(features)):
assert features[i] == features2[i]
print("Equal")
print("DONE")
import sys
sys.exit()
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)