This commit is contained in:
Lysandre
2019-11-19 09:49:55 -05:00
committed by LysandreJik
parent ea52f82455
commit 72e506b22e
6 changed files with 157 additions and 5 deletions

View File

@@ -23,7 +23,6 @@ import os
import random
import glob
import timeit
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
@@ -45,7 +44,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
from utils_squad import (read_squad_examples, convert_examples_to_features,
RawResult, write_predictions,
@@ -309,6 +308,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
examples = read_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
examples = examples[:10]
features = convert_examples_to_features(examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
@@ -319,6 +320,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
exampless = sread_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
exampless = exampless[:10]
features2 = squad_convert_examples_to_features(examples=exampless,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
print(features2)
for i in range(len(features)):
assert features[i] == features2[i]
print("Equal")
print("DONE")
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)