add doc string

This commit is contained in:
erenup
2019-09-16 11:50:18 +08:00
parent 6e1ac34e2b
commit 4812a5a767
4 changed files with 158 additions and 36 deletions

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
from __future__ import absolute_import, division, print_function
@@ -44,7 +44,7 @@ from utils_multiple_choice import (convert_examples_to_features, processors)
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
@@ -208,7 +208,6 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix="", test=False):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = (args.task_name,)
eval_outputs_dirs = (args.output_dir,)
@@ -259,7 +258,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
result = {"eval_acc": acc, "eval_loss": eval_loss}
results.update(result)
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt")
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
@@ -522,9 +521,9 @@ def main():
if not args.do_train:
args.output_dir = args.model_name_or_path
checkpoints = [args.output_dir]
if args.eval_all_checkpoints: #can not use this to do test!! just for different paras
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
# if args.eval_all_checkpoints: # can not use this to do test!!
# checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
# logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """
""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
from __future__ import absolute_import, division, print_function
@@ -38,11 +38,10 @@ class InputExample(object):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
example_id: Unique id for the example.
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
question: string. The untokenized text of the second sequence (qustion).
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
@@ -73,7 +72,7 @@ class InputFeatures(object):
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
"""Base class for data converters for multiple choice data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
@@ -84,7 +83,7 @@ class DataProcessor(object):
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
"""Gets a collection of `InputExample`s for the test set."""
raise NotImplementedError()
def get_labels(self):
@@ -93,7 +92,7 @@ class DataProcessor(object):
class RaceProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the RACE data set."""
def get_train_examples(self, data_dir):
"""See base class."""
@@ -152,13 +151,13 @@ class RaceProcessor(DataProcessor):
InputExample(
example_id=race_id,
question=question,
contexts=[article, article, article, article],
contexts=[article, article, article, article], # this is not efficient but convenient
endings=[options[0], options[1], options[2], options[3]],
label=truth))
return examples
class SwagProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the SWAG data set."""
def get_train_examples(self, data_dir):
"""See base class."""
@@ -172,9 +171,12 @@ class SwagProcessor(DataProcessor):
def get_test_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {} test".format(data_dir))
logger.info("LOOKING AT {} dev".format(data_dir))
raise ValueError(
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
"setting!"
)
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1", "2", "3"]
@@ -213,7 +215,7 @@ class SwagProcessor(DataProcessor):
class ArcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the ARC data set (request from allennlp)."""
def get_train_examples(self, data_dir):
"""See base class."""
@@ -242,6 +244,7 @@ class ArcProcessor(DataProcessor):
def _create_examples(self, lines, type):
"""Creates examples for the training and dev sets."""
#There are two types of labels. They should be normalized
def normalize(truth):
if truth in "ABCD":
return ord(truth) - ord("A")
@@ -256,6 +259,7 @@ class ArcProcessor(DataProcessor):
four_choice = 0
five_choice = 0
other_choices = 0
# we deleted example which has more than or less than four choices
for line in tqdm.tqdm(lines, desc="read arc data"):
data_raw = json.loads(line.strip("\n"))
if len(data_raw["question"]["choices"]) == 3:
@@ -274,7 +278,6 @@ class ArcProcessor(DataProcessor):
question = question_choices["stem"]
id = data_raw["id"]
options = question_choices["choices"]
if len(options) == 4:
examples.append(
InputExample(
@@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
tokens_a = tokenizer.tokenize(context)
tokens_b = None
if example.question.find("_") != -1:
#this is for cloze question
tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
else:
tokens_b = tokenizer.tokenize(example.question)
tokens_b += [sep_token]
if sep_token_extra:
tokens_b += [sep_token]
tokens_b += tokenizer.tokenize(ending)
tokens_b = tokenizer.tokenize(example.question + " " + ending)
# you can add seq token between quesiotn and ending. This does not make too much difference.
# tokens_b = tokenizer.tokenize(example.question)
# tokens_b += [sep_token]
# if sep_token_extra:
# tokens_b += [sep_token]
# tokens_b += tokenizer.tokenize(ending)
special_tokens_count = 4 if sep_token_extra else 3
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
@@ -427,15 +433,18 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
# length or only pop from context
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
# if len(tokens_a) > len(tokens_b):
# tokens_a.pop()
# else:
# tokens_b.pop()
tokens_a.pop()
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
logger.info('Attention! you are removing from question + options. Try to use a bigger max seq length!')
tokens_b.pop()
processors = {