Merge branch 'master' into run_multiple_choice_merge

# Please enter a commit message to explain why this merge is necessary,
# especially if it merges an updated upstream into a topic branch.
#
# Lines starting with '#' will be ignored, and an empty message aborts
# the commit.
This commit is contained in:
erenup
2019-09-18 21:43:46 +08:00
5 changed files with 198 additions and 39 deletions

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """
""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
from __future__ import absolute_import, division, print_function
@@ -38,11 +38,10 @@ class InputExample(object):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
example_id: Unique id for the example.
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
question: string. The untokenized text of the second sequence (qustion).
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
@@ -73,7 +72,7 @@ class InputFeatures(object):
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
"""Base class for data converters for multiple choice data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
@@ -84,7 +83,7 @@ class DataProcessor(object):
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
"""Gets a collection of `InputExample`s for the test set."""
raise NotImplementedError()
def get_labels(self):
@@ -93,7 +92,7 @@ class DataProcessor(object):
class RaceProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the RACE data set."""
def get_train_examples(self, data_dir):
"""See base class."""
@@ -152,13 +151,13 @@ class RaceProcessor(DataProcessor):
InputExample(
example_id=race_id,
question=question,
contexts=[article, article, article, article],
contexts=[article, article, article, article], # this is not efficient but convenient
endings=[options[0], options[1], options[2], options[3]],
label=truth))
return examples
class SwagProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the SWAG data set."""
def get_train_examples(self, data_dir):
"""See base class."""
@@ -172,9 +171,12 @@ class SwagProcessor(DataProcessor):
def get_test_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {} test".format(data_dir))
logger.info("LOOKING AT {} dev".format(data_dir))
raise ValueError(
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
"setting!"
)
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1", "2", "3"]
@@ -213,7 +215,7 @@ class SwagProcessor(DataProcessor):
class ArcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the ARC data set (request from allennlp)."""
def get_train_examples(self, data_dir):
"""See base class."""
@@ -242,6 +244,7 @@ class ArcProcessor(DataProcessor):
def _create_examples(self, lines, type):
"""Creates examples for the training and dev sets."""
#There are two types of labels. They should be normalized
def normalize(truth):
if truth in "ABCD":
return ord(truth) - ord("A")
@@ -256,6 +259,7 @@ class ArcProcessor(DataProcessor):
four_choice = 0
five_choice = 0
other_choices = 0
# we deleted example which has more than or less than four choices
for line in tqdm.tqdm(lines, desc="read arc data"):
data_raw = json.loads(line.strip("\n"))
if len(data_raw["question"]["choices"]) == 3:
@@ -274,7 +278,6 @@ class ArcProcessor(DataProcessor):
question = question_choices["stem"]
id = data_raw["id"]
options = question_choices["choices"]
if len(options) == 4:
examples.append(
InputExample(
@@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
tokens_a = tokenizer.tokenize(context)
tokens_b = None
if example.question.find("_") != -1:
#this is for cloze question
tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
else:
tokens_b = tokenizer.tokenize(example.question)
tokens_b += [sep_token]
if sep_token_extra:
tokens_b += [sep_token]
tokens_b += tokenizer.tokenize(ending)
tokens_b = tokenizer.tokenize(example.question + " " + ending)
# you can add seq token between quesiotn and ending. This does not make too much difference.
# tokens_b = tokenizer.tokenize(example.question)
# tokens_b += [sep_token]
# if sep_token_extra:
# tokens_b += [sep_token]
# tokens_b += tokenizer.tokenize(ending)
special_tokens_count = 4 if sep_token_extra else 3
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
@@ -427,15 +433,20 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
# length or only pop from context
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
# if len(tokens_a) > len(tokens_b):
# tokens_a.pop()
# else:
# tokens_b.pop()
tokens_a.pop()
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
logger.info('Attention! you are removing from token_b (swag task is ok). '
'If you are training ARC and RACE (you are poping question + options), '
'you need to try to use a bigger max seq length!')
tokens_b.pop()
processors = {