wip
This commit is contained in:
@@ -23,7 +23,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import timeit
|
import timeit
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||||
@@ -45,7 +44,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
|
|||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
|
||||||
|
|
||||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||||
RawResult, write_predictions,
|
RawResult, write_predictions,
|
||||||
@@ -309,6 +308,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
examples = read_squad_examples(input_file=input_file,
|
examples = read_squad_examples(input_file=input_file,
|
||||||
is_training=not evaluate,
|
is_training=not evaluate,
|
||||||
version_2_with_negative=args.version_2_with_negative)
|
version_2_with_negative=args.version_2_with_negative)
|
||||||
|
|
||||||
|
examples = examples[:10]
|
||||||
features = convert_examples_to_features(examples=examples,
|
features = convert_examples_to_features(examples=examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
@@ -319,6 +320,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||||
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||||
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||||
|
|
||||||
|
exampless = sread_squad_examples(input_file=input_file,
|
||||||
|
is_training=not evaluate,
|
||||||
|
version_2_with_negative=args.version_2_with_negative)
|
||||||
|
exampless = exampless[:10]
|
||||||
|
features2 = squad_convert_examples_to_features(examples=exampless,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
doc_stride=args.doc_stride,
|
||||||
|
max_query_length=args.max_query_length,
|
||||||
|
is_training=not evaluate,
|
||||||
|
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||||
|
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||||
|
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||||
|
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||||
|
|
||||||
|
print(features2)
|
||||||
|
|
||||||
|
for i in range(len(features)):
|
||||||
|
assert features[i] == features2[i]
|
||||||
|
print("Equal")
|
||||||
|
|
||||||
|
print("DONE")
|
||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save(features, cached_features_file)
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ from .data import (is_sklearn_available,
|
|||||||
InputExample, InputFeatures, DataProcessor,
|
InputExample, InputFeatures, DataProcessor,
|
||||||
glue_output_modes, glue_convert_examples_to_features,
|
glue_output_modes, glue_convert_examples_to_features,
|
||||||
glue_processors, glue_tasks_num_labels,
|
glue_processors, glue_tasks_num_labels,
|
||||||
squad_convert_examples_to_features, SquadFeatures)
|
squad_convert_examples_to_features, SquadFeatures,
|
||||||
|
SquadExample, read_squad_examples)
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .data import glue_compute_metrics
|
from .data import glue_compute_metrics
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
|
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
|
||||||
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||||
from .processors import squad_convert_examples_to_features
|
from .processors import squad_convert_examples_to_features, SquadExample, read_squad_examples
|
||||||
|
|
||||||
from .metrics import is_sklearn_available
|
from .metrics import is_sklearn_available
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .utils import InputExample, InputFeatures, DataProcessor
|
from .utils import InputExample, InputFeatures, DataProcessor
|
||||||
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||||
from .squad import squad_convert_examples_to_features, SquadFeatures
|
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, read_squad_examples
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ from tqdm import tqdm
|
|||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
from .utils import DataProcessor, InputExample, InputFeatures
|
from .utils import DataProcessor, InputExample, InputFeatures
|
||||||
from ...file_utils import is_tf_available
|
from ...file_utils import is_tf_available
|
||||||
|
|
||||||
@@ -11,6 +13,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||||
doc_stride, max_query_length, is_training,
|
doc_stride, max_query_length, is_training,
|
||||||
cls_token_at_end=False,
|
cls_token_at_end=False,
|
||||||
@@ -265,6 +268,125 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def read_squad_examples(input_file, is_training, version_2_with_negative):
    """Load a SQuAD-format JSON file and build one SquadExample per question.

    Args:
        input_file: Path to a JSON file whose top-level "data" entry is a list
            of articles; each article has "paragraphs", and each paragraph has
            a "context" string and a list of "qas".
        is_training: When true, each question's answer is resolved to
            whitespace-token positions and stored on the example.
        version_2_with_negative: When true (training only), the per-question
            "is_impossible" flag is read from the file.

    Returns:
        A list of SquadExample. In training mode, a question whose answer text
        cannot be recovered from the context is logged and left out, so the
        list may be shorter than the number of questions in the file.

    Raises:
        ValueError: In training mode, when a question that is not marked
            impossible does not have exactly one answer.
    """

    def split_context(text):
        """Split `text` on whitespace and map each char index to its word index."""
        words = []
        char_to_word = []
        start_new_word = True
        for ch in text:
            # 0x202F is the narrow no-break space; it is treated as a separator.
            if ch in " \t\r\n" or ord(ch) == 0x202F:
                start_new_word = True
            elif start_new_word:
                words.append(ch)
                start_new_word = False
            else:
                words[-1] += ch
            char_to_word.append(len(words) - 1)
        return words, char_to_word

    with open(input_file, "r", encoding='utf-8') as reader:
        articles = json.load(reader)["data"]

    examples = []
    for article in articles:
        for paragraph in article["paragraphs"]:
            # The same token list object is shared by every example built
            # from this paragraph.
            doc_tokens, char_to_word_offset = split_context(paragraph["context"])

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False

                if is_training:
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    answer_count = len(qa["answers"])
                    if answer_count != 1 and not is_impossible:
                        raise ValueError(
                            "For training, each question should have exactly 1 answer.")

                    if is_impossible:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""
                    else:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        first_char = answer["answer_start"]
                        last_char = first_char + len(orig_answer_text) - 1
                        start_position = char_to_word_offset[first_char]
                        end_position = char_to_word_offset[last_char]

                        # Keep the example only if the answer text can be
                        # recovered exactly from the document tokens. A
                        # mismatch is most likely caused by unusual Unicode,
                        # in which case the example is skipped; training mode
                        # therefore does not preserve every question.
                        actual_text = " ".join(
                            doc_tokens[start_position:end_position + 1])
                        cleaned_answer_text = " ".join(
                            whitespace_tokenize(orig_answer_text))
                        if cleaned_answer_text not in actual_text:
                            logger.warning("Could not find answer: '%s' vs. '%s'",
                                           actual_text, cleaned_answer_text)
                            continue

                examples.append(SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible))
    return examples
|
||||||
|
|
||||||
|
|
||||||
|
class SquadExample(object):
    """
    A single training/test example for the Squad dataset.
    For examples without an answer, the start and end position are -1.

    Attributes mirror the constructor arguments:
        qas_id: Identifier of the question.
        question_text: The question string.
        doc_tokens: List of string tokens of the context (joined with spaces
            when the example is printed).
        orig_answer_text: The answer text, or None when not provided.
        start_position: Index into doc_tokens of the first answer token,
            or None when not provided.
        end_position: Index into doc_tokens of the last answer token,
            or None when not provided.
        is_impossible: Truthy when the question has no answer.
    """

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Return a one-line, comma-separated summary of the example."""
        parts = [
            "qas_id: %s" % (self.qas_id),
            "question_text: %s" % (self.question_text),
            "doc_tokens: [%s]" % (" ".join(self.doc_tokens)),
        ]
        # Compare against None rather than relying on truthiness: position 0
        # is a valid token index and must still be shown.
        if self.start_position is not None:
            parts.append("start_position: %d" % (self.start_position))
        if self.end_position is not None:
            parts.append("end_position: %d" % (self.end_position))
        # The flag is only shown when it is set.
        if self.is_impossible:
            parts.append("is_impossible: %r" % (self.is_impossible))
        return ", ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class SquadFeatures(object):
|
class SquadFeatures(object):
|
||||||
"""A single set of features of data."""
|
"""A single set of features of data."""
|
||||||
|
|
||||||
|
|||||||
@@ -605,6 +605,10 @@ class PreTrainedTokenizer(object):
|
|||||||
vocabularies (BPE/SentencePieces/WordPieces).
|
vocabularies (BPE/SentencePieces/WordPieces).
|
||||||
|
|
||||||
Take care of added tokens.
|
Take care of added tokens.
|
||||||
|
|
||||||
|
text: The sequence to be encoded.
|
||||||
|
return_tokens_mapped_to_origin: (optional) Set to True to return the index of each token in the initial whitespace tokenization. (default False).
|
||||||
|
**kwargs: passed to the child `self.tokenize()` method
|
||||||
"""
|
"""
|
||||||
def split_on_token(tok, text):
|
def split_on_token(tok, text):
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
Reference in New Issue
Block a user