Cleanup & Evaluation now works

This commit is contained in:
Lysandre
2019-11-28 16:03:56 -05:00
parent 0669c1fcd1
commit bd41e8292a
2 changed files with 20 additions and 38 deletions

View File

@@ -16,7 +16,7 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
import argparse
import logging
@@ -45,9 +45,9 @@ from transformers import (WEIGHTS_NAME, BertConfig,
XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
from utils_squad import (RawResult, write_predictions,
from utils_squad import (convert_examples_to_features as old_convert, read_squad_examples as old_read, RawResult, write_predictions,
RawResultExtended, write_predictions_extended)
# The follwing import is the official SQuAD evaluation script (2.0).
@@ -304,28 +304,20 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", input_file)
examples = read_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
keep_n_examples = 1000
processor = SquadV1Processor()
values = processor.get_dev_examples("examples/squad")
examples = values[:keep_n_examples]
features = squad_convert_examples_to_features(examples=exampless,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
print("DONE")
import sys
sys.exit()
processor = SquadV2Processor()
examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
features = squad_convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
@@ -335,8 +327,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
if evaluate: