Works for XLNet
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from transformers.data.processors.squad import SquadV1Processor
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
|
||||
|
||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
from utils_squad import (RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
|
||||
# The following import is the official SQuAD evaluation script (2.0).
|
||||
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
results = evaluate_on_squad(evaluate_options)
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
|
||||
@@ -308,9 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
examples = read_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
|
||||
examples = examples[:10]
|
||||
features = convert_examples_to_features(examples=examples,
|
||||
keep_n_examples = 1000
|
||||
processor = SquadV1Processor()
|
||||
values = processor.get_dev_examples("examples/squad")
|
||||
examples = values[:keep_n_examples]
|
||||
features = squad_convert_examples_to_features(examples=exampless,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
@@ -320,29 +321,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||
|
||||
exampless = sread_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
exampless = exampless[:10]
|
||||
features2 = squad_convert_examples_to_features(examples=exampless,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||
|
||||
print(features2)
|
||||
|
||||
for i in range(len(features)):
|
||||
assert features[i] == features2[i]
|
||||
print("Equal")
|
||||
|
||||
print("DONE")
|
||||
|
||||
import sys
|
||||
sys.exit()
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
|
||||
Reference in New Issue
Block a user