Convert indentation from 2 spaces to 4 spaces
This commit is contained in:
@@ -63,379 +63,379 @@ flags.DEFINE_float(
|
||||
|
||||
|
||||
class TrainingInstance(object):
|
||||
"""A single training instance (sentence pair)."""
|
||||
"""A single training instance (sentence pair)."""
|
||||
|
||||
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||
is_random_next):
|
||||
self.tokens = tokens
|
||||
self.segment_ids = segment_ids
|
||||
self.is_random_next = is_random_next
|
||||
self.masked_lm_positions = masked_lm_positions
|
||||
self.masked_lm_labels = masked_lm_labels
|
||||
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||
is_random_next):
|
||||
self.tokens = tokens
|
||||
self.segment_ids = segment_ids
|
||||
self.is_random_next = is_random_next
|
||||
self.masked_lm_positions = masked_lm_positions
|
||||
self.masked_lm_labels = masked_lm_labels
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.tokens]))
|
||||
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||
s += "is_random_next: %s\n" % self.is_random_next
|
||||
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||
[str(x) for x in self.masked_lm_positions]))
|
||||
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||
s += "\n"
|
||||
return s
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.tokens]))
|
||||
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||
s += "is_random_next: %s\n" % self.is_random_next
|
||||
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||
[str(x) for x in self.masked_lm_positions]))
|
||||
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||
max_predictions_per_seq, output_files):
|
||||
"""Create TF example files from `TrainingInstance`s."""
|
||||
writers = []
|
||||
for output_file in output_files:
|
||||
writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||
"""Create TF example files from `TrainingInstance`s."""
|
||||
writers = []
|
||||
for output_file in output_files:
|
||||
writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||
|
||||
writer_index = 0
|
||||
writer_index = 0
|
||||
|
||||
total_written = 0
|
||||
for (inst_index, instance) in enumerate(instances):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = list(instance.segment_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
total_written = 0
|
||||
for (inst_index, instance) in enumerate(instances):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = list(instance.segment_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
masked_lm_positions = list(instance.masked_lm_positions)
|
||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||
masked_lm_positions = list(instance.masked_lm_positions)
|
||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||
|
||||
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||
masked_lm_positions.append(0)
|
||||
masked_lm_ids.append(0)
|
||||
masked_lm_weights.append(0.0)
|
||||
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||
masked_lm_positions.append(0)
|
||||
masked_lm_ids.append(0)
|
||||
masked_lm_weights.append(0.0)
|
||||
|
||||
next_sentence_label = 1 if instance.is_random_next else 0
|
||||
next_sentence_label = 1 if instance.is_random_next else 0
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(input_ids)
|
||||
features["input_mask"] = create_int_feature(input_mask)
|
||||
features["segment_ids"] = create_int_feature(segment_ids)
|
||||
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(input_ids)
|
||||
features["input_mask"] = create_int_feature(input_mask)
|
||||
features["segment_ids"] = create_int_feature(segment_ids)
|
||||
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
|
||||
writers[writer_index].write(tf_example.SerializeToString())
|
||||
writer_index = (writer_index + 1) % len(writers)
|
||||
writers[writer_index].write(tf_example.SerializeToString())
|
||||
writer_index = (writer_index + 1) % len(writers)
|
||||
|
||||
total_written += 1
|
||||
total_written += 1
|
||||
|
||||
if inst_index < 20:
|
||||
tf.logging.info("*** Example ***")
|
||||
tf.logging.info("tokens: %s" % " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||
if inst_index < 20:
|
||||
tf.logging.info("*** Example ***")
|
||||
tf.logging.info("tokens: %s" % " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
values = []
|
||||
if feature.int64_list.value:
|
||||
values = feature.int64_list.value
|
||||
elif feature.float_list.value:
|
||||
values = feature.float_list.value
|
||||
tf.logging.info(
|
||||
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
values = []
|
||||
if feature.int64_list.value:
|
||||
values = feature.int64_list.value
|
||||
elif feature.float_list.value:
|
||||
values = feature.float_list.value
|
||||
tf.logging.info(
|
||||
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
|
||||
tf.logging.info("Wrote %d total instances", total_written)
|
||||
tf.logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_float_feature(values):
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
return feature
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_training_instances(input_files, tokenizer, max_seq_length,
|
||||
dupe_factor, short_seq_prob, masked_lm_prob,
|
||||
max_predictions_per_seq, rng):
|
||||
"""Create `TrainingInstance`s from raw text."""
|
||||
all_documents = [[]]
|
||||
"""Create `TrainingInstance`s from raw text."""
|
||||
all_documents = [[]]
|
||||
|
||||
# Input file format:
|
||||
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||
# sentence boundaries for the "next sentence prediction" task).
|
||||
# (2) Blank lines between documents. Document boundaries are needed so
|
||||
# that the "next sentence prediction" task doesn't span between documents.
|
||||
for input_file in input_files:
|
||||
with tf.gfile.GFile(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
# Input file format:
|
||||
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||
# sentence boundaries for the "next sentence prediction" task).
|
||||
# (2) Blank lines between documents. Document boundaries are needed so
|
||||
# that the "next sentence prediction" task doesn't span between documents.
|
||||
for input_file in input_files:
|
||||
with tf.gfile.GFile(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
|
||||
# Empty lines are used as document delimiters
|
||||
if not line:
|
||||
all_documents.append([])
|
||||
tokens = tokenizer.tokenize(line)
|
||||
if tokens:
|
||||
all_documents[-1].append(tokens)
|
||||
# Empty lines are used as document delimiters
|
||||
if not line:
|
||||
all_documents.append([])
|
||||
tokens = tokenizer.tokenize(line)
|
||||
if tokens:
|
||||
all_documents[-1].append(tokens)
|
||||
|
||||
# Remove empty documents
|
||||
all_documents = [x for x in all_documents if x]
|
||||
rng.shuffle(all_documents)
|
||||
# Remove empty documents
|
||||
all_documents = [x for x in all_documents if x]
|
||||
rng.shuffle(all_documents)
|
||||
|
||||
vocab_words = list(tokenizer.vocab.keys())
|
||||
instances = []
|
||||
for _ in range(dupe_factor):
|
||||
for document_index in range(len(all_documents)):
|
||||
instances.extend(
|
||||
create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||
vocab_words = list(tokenizer.vocab.keys())
|
||||
instances = []
|
||||
for _ in range(dupe_factor):
|
||||
for document_index in range(len(all_documents)):
|
||||
instances.extend(
|
||||
create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
|
||||
|
||||
def create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates `TrainingInstance`s for a single document."""
|
||||
document = all_documents[document_index]
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates `TrainingInstance`s for a single document."""
|
||||
document = all_documents[document_index]
|
||||
|
||||
# Account for [CLS], [SEP], [SEP]
|
||||
max_num_tokens = max_seq_length - 3
|
||||
# Account for [CLS], [SEP], [SEP]
|
||||
max_num_tokens = max_seq_length - 3
|
||||
|
||||
# We *usually* want to fill up the entire sequence since we are padding
|
||||
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||
# computation. However, we *sometimes*
|
||||
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||
# The `target_seq_length` is just a rough target however, whereas
|
||||
# `max_seq_length` is a hard limit.
|
||||
target_seq_length = max_num_tokens
|
||||
if rng.random() < short_seq_prob:
|
||||
target_seq_length = rng.randint(2, max_num_tokens)
|
||||
# We *usually* want to fill up the entire sequence since we are padding
|
||||
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||
# computation. However, we *sometimes*
|
||||
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||
# The `target_seq_length` is just a rough target however, whereas
|
||||
# `max_seq_length` is a hard limit.
|
||||
target_seq_length = max_num_tokens
|
||||
if rng.random() < short_seq_prob:
|
||||
target_seq_length = rng.randint(2, max_num_tokens)
|
||||
|
||||
# We DON'T just concatenate all of the tokens from a document into a long
|
||||
# sequence and choose an arbitrary split point because this would make the
|
||||
# next sentence prediction task too easy. Instead, we split the input into
|
||||
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||
# input.
|
||||
instances = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i = 0
|
||||
while i < len(document):
|
||||
segment = document[i]
|
||||
current_chunk.append(segment)
|
||||
current_length += len(segment)
|
||||
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||
if current_chunk:
|
||||
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||
# (first) sentence.
|
||||
a_end = 1
|
||||
if len(current_chunk) >= 2:
|
||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||
# We DON'T just concatenate all of the tokens from a document into a long
|
||||
# sequence and choose an arbitrary split point because this would make the
|
||||
# next sentence prediction task too easy. Instead, we split the input into
|
||||
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||
# input.
|
||||
instances = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i = 0
|
||||
while i < len(document):
|
||||
segment = document[i]
|
||||
current_chunk.append(segment)
|
||||
current_length += len(segment)
|
||||
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||
if current_chunk:
|
||||
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||
# (first) sentence.
|
||||
a_end = 1
|
||||
if len(current_chunk) >= 2:
|
||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(current_chunk[j])
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(current_chunk[j])
|
||||
|
||||
tokens_b = []
|
||||
# Random next
|
||||
is_random_next = False
|
||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||
is_random_next = True
|
||||
target_b_length = target_seq_length - len(tokens_a)
|
||||
tokens_b = []
|
||||
# Random next
|
||||
is_random_next = False
|
||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||
is_random_next = True
|
||||
target_b_length = target_seq_length - len(tokens_a)
|
||||
|
||||
# This should rarely go for more than one iteration for large
|
||||
# corpora. However, just to be careful, we try to make sure that
|
||||
# the random document is not the same as the document
|
||||
# we're processing.
|
||||
for _ in range(10):
|
||||
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||
if random_document_index != document_index:
|
||||
break
|
||||
# This should rarely go for more than one iteration for large
|
||||
# corpora. However, just to be careful, we try to make sure that
|
||||
# the random document is not the same as the document
|
||||
# we're processing.
|
||||
for _ in range(10):
|
||||
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||
if random_document_index != document_index:
|
||||
break
|
||||
|
||||
random_document = all_documents[random_document_index]
|
||||
random_start = rng.randint(0, len(random_document) - 1)
|
||||
for j in range(random_start, len(random_document)):
|
||||
tokens_b.extend(random_document[j])
|
||||
if len(tokens_b) >= target_b_length:
|
||||
break
|
||||
# We didn't actually use these segments so we "put them back" so
|
||||
# they don't go to waste.
|
||||
num_unused_segments = len(current_chunk) - a_end
|
||||
i -= num_unused_segments
|
||||
# Actual next
|
||||
else:
|
||||
is_random_next = False
|
||||
for j in range(a_end, len(current_chunk)):
|
||||
tokens_b.extend(current_chunk[j])
|
||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||
random_document = all_documents[random_document_index]
|
||||
random_start = rng.randint(0, len(random_document) - 1)
|
||||
for j in range(random_start, len(random_document)):
|
||||
tokens_b.extend(random_document[j])
|
||||
if len(tokens_b) >= target_b_length:
|
||||
break
|
||||
# We didn't actually use these segments so we "put them back" so
|
||||
# they don't go to waste.
|
||||
num_unused_segments = len(current_chunk) - a_end
|
||||
i -= num_unused_segments
|
||||
# Actual next
|
||||
else:
|
||||
is_random_next = False
|
||||
for j in range(a_end, len(current_chunk)):
|
||||
tokens_b.extend(current_chunk[j])
|
||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||
|
||||
assert len(tokens_a) >= 1
|
||||
assert len(tokens_b) >= 1
|
||||
assert len(tokens_a) >= 1
|
||||
assert len(tokens_b) >= 1
|
||||
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
(tokens, masked_lm_positions,
|
||||
masked_lm_labels) = create_masked_lm_predictions(
|
||||
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||
instance = TrainingInstance(
|
||||
tokens=tokens,
|
||||
segment_ids=segment_ids,
|
||||
is_random_next=is_random_next,
|
||||
masked_lm_positions=masked_lm_positions,
|
||||
masked_lm_labels=masked_lm_labels)
|
||||
instances.append(instance)
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i += 1
|
||||
(tokens, masked_lm_positions,
|
||||
masked_lm_labels) = create_masked_lm_predictions(
|
||||
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||
instance = TrainingInstance(
|
||||
tokens=tokens,
|
||||
segment_ids=segment_ids,
|
||||
is_random_next=is_random_next,
|
||||
masked_lm_positions=masked_lm_positions,
|
||||
masked_lm_labels=masked_lm_labels)
|
||||
instances.append(instance)
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i += 1
|
||||
|
||||
return instances
|
||||
return instances
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||
max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates the predictis for the masked LM objective."""
|
||||
"""Creates the predictis for the masked LM objective."""
|
||||
|
||||
cand_indexes = []
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
cand_indexes.append(i)
|
||||
cand_indexes = []
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
cand_indexes.append(i)
|
||||
|
||||
rng.shuffle(cand_indexes)
|
||||
rng.shuffle(cand_indexes)
|
||||
|
||||
output_tokens = list(tokens)
|
||||
output_tokens = list(tokens)
|
||||
|
||||
masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
|
||||
masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
|
||||
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for index in cand_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
covered_indexes.add(index)
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for index in cand_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
covered_indexes.add(index)
|
||||
|
||||
masked_token = None
|
||||
# 80% of the time, replace with [MASK]
|
||||
if rng.random() < 0.8:
|
||||
masked_token = "[MASK]"
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||
masked_token = None
|
||||
# 80% of the time, replace with [MASK]
|
||||
if rng.random() < 0.8:
|
||||
masked_token = "[MASK]"
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
output_tokens[index] = masked_token
|
||||
|
||||
masked_lms.append(masked_lm(index=index, label=tokens[index]))
|
||||
masked_lms.append(masked_lm(index=index, label=tokens[index]))
|
||||
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||
|
||||
|
||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_num_tokens:
|
||||
break
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_num_tokens:
|
||||
break
|
||||
|
||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||
assert len(trunc_tokens) >= 1
|
||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||
assert len(trunc_tokens) >= 1
|
||||
|
||||
# We want to sometimes truncate from the front and sometimes from the
|
||||
# back to add more randomness and avoid biases.
|
||||
if rng.random() < 0.5:
|
||||
del trunc_tokens[0]
|
||||
else:
|
||||
trunc_tokens.pop()
|
||||
# We want to sometimes truncate from the front and sometimes from the
|
||||
# back to add more randomness and avoid biases.
|
||||
if rng.random() < 0.5:
|
||||
del trunc_tokens[0]
|
||||
else:
|
||||
trunc_tokens.pop()
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
|
||||
tf.logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
tf.logging.info(" %s", input_file)
|
||||
tf.logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
tf.logging.info(" %s", input_file)
|
||||
|
||||
rng = random.Random(FLAGS.random_seed)
|
||||
instances = create_training_instances(
|
||||
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||
rng)
|
||||
rng = random.Random(FLAGS.random_seed)
|
||||
instances = create_training_instances(
|
||||
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||
rng)
|
||||
|
||||
output_files = FLAGS.output_file.split(",")
|
||||
tf.logging.info("*** Writing to output files ***")
|
||||
for output_file in output_files:
|
||||
tf.logging.info(" %s", output_file)
|
||||
output_files = FLAGS.output_file.split(",")
|
||||
tf.logging.info("*** Writing to output files ***")
|
||||
for output_file in output_files:
|
||||
tf.logging.info(" %s", output_file)
|
||||
|
||||
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||
FLAGS.max_predictions_per_seq, output_files)
|
||||
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||
FLAGS.max_predictions_per_seq, output_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
tf.app.run()
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
tf.app.run()
|
||||
|
||||
@@ -80,330 +80,330 @@ flags.DEFINE_bool(
|
||||
|
||||
class InputExample(object):
|
||||
|
||||
def __init__(self, unique_id, text_a, text_b):
|
||||
self.unique_id = unique_id
|
||||
self.text_a = text_a
|
||||
self.text_b = text_b
|
||||
def __init__(self, unique_id, text_a, text_b):
|
||||
self.unique_id = unique_id
|
||||
self.text_a = text_a
|
||||
self.text_b = text_b
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
|
||||
self.unique_id = unique_id
|
||||
self.tokens = tokens
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.input_type_ids = input_type_ids
|
||||
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
|
||||
self.unique_id = unique_id
|
||||
self.tokens = tokens
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.input_type_ids = input_type_ids
|
||||
|
||||
|
||||
def input_fn_builder(features, seq_length):
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||
|
||||
all_unique_ids = []
|
||||
all_input_ids = []
|
||||
all_input_mask = []
|
||||
all_input_type_ids = []
|
||||
all_unique_ids = []
|
||||
all_input_ids = []
|
||||
all_input_mask = []
|
||||
all_input_type_ids = []
|
||||
|
||||
for feature in features:
|
||||
all_unique_ids.append(feature.unique_id)
|
||||
all_input_ids.append(feature.input_ids)
|
||||
all_input_mask.append(feature.input_mask)
|
||||
all_input_type_ids.append(feature.input_type_ids)
|
||||
for feature in features:
|
||||
all_unique_ids.append(feature.unique_id)
|
||||
all_input_ids.append(feature.input_ids)
|
||||
all_input_mask.append(feature.input_mask)
|
||||
all_input_type_ids.append(feature.input_type_ids)
|
||||
|
||||
def input_fn(params):
|
||||
"""The actual input function."""
|
||||
batch_size = params["batch_size"]
|
||||
def input_fn(params):
|
||||
"""The actual input function."""
|
||||
batch_size = params["batch_size"]
|
||||
|
||||
num_examples = len(features)
|
||||
num_examples = len(features)
|
||||
|
||||
# This is for demo purposes and does NOT scale to large data sets. We do
|
||||
# not use Dataset.from_generator() because that uses tf.py_func which is
|
||||
# not TPU compatible. The right way to load data is with TFRecordReader.
|
||||
d = tf.data.Dataset.from_tensor_slices({
|
||||
"unique_ids":
|
||||
tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
|
||||
"input_ids":
|
||||
tf.constant(
|
||||
all_input_ids, shape=[num_examples, seq_length],
|
||||
dtype=tf.int32),
|
||||
"input_mask":
|
||||
tf.constant(
|
||||
all_input_mask,
|
||||
shape=[num_examples, seq_length],
|
||||
dtype=tf.int32),
|
||||
"input_type_ids":
|
||||
tf.constant(
|
||||
all_input_type_ids,
|
||||
shape=[num_examples, seq_length],
|
||||
dtype=tf.int32),
|
||||
})
|
||||
# This is for demo purposes and does NOT scale to large data sets. We do
|
||||
# not use Dataset.from_generator() because that uses tf.py_func which is
|
||||
# not TPU compatible. The right way to load data is with TFRecordReader.
|
||||
d = tf.data.Dataset.from_tensor_slices({
|
||||
"unique_ids":
|
||||
tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
|
||||
"input_ids":
|
||||
tf.constant(
|
||||
all_input_ids, shape=[num_examples, seq_length],
|
||||
dtype=tf.int32),
|
||||
"input_mask":
|
||||
tf.constant(
|
||||
all_input_mask,
|
||||
shape=[num_examples, seq_length],
|
||||
dtype=tf.int32),
|
||||
"input_type_ids":
|
||||
tf.constant(
|
||||
all_input_type_ids,
|
||||
shape=[num_examples, seq_length],
|
||||
dtype=tf.int32),
|
||||
})
|
||||
|
||||
d = d.batch(batch_size=batch_size, drop_remainder=False)
|
||||
return d
|
||||
d = d.batch(batch_size=batch_size, drop_remainder=False)
|
||||
return d
|
||||
|
||||
return input_fn
|
||||
return input_fn
|
||||
|
||||
|
||||
def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
|
||||
use_one_hot_embeddings):
|
||||
"""Returns `model_fn` closure for TPUEstimator."""
|
||||
"""Returns `model_fn` closure for TPUEstimator."""
|
||||
|
||||
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||
"""The `model_fn` for TPUEstimator."""
|
||||
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||
"""The `model_fn` for TPUEstimator."""
|
||||
|
||||
unique_ids = features["unique_ids"]
|
||||
input_ids = features["input_ids"]
|
||||
input_mask = features["input_mask"]
|
||||
input_type_ids = features["input_type_ids"]
|
||||
unique_ids = features["unique_ids"]
|
||||
input_ids = features["input_ids"]
|
||||
input_mask = features["input_mask"]
|
||||
input_type_ids = features["input_type_ids"]
|
||||
|
||||
model = modeling.BertModel(
|
||||
config=bert_config,
|
||||
is_training=False,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
token_type_ids=input_type_ids,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
model = modeling.BertModel(
|
||||
config=bert_config,
|
||||
is_training=False,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
token_type_ids=input_type_ids,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
if mode != tf.estimator.ModeKeys.PREDICT:
|
||||
raise ValueError("Only PREDICT modes are supported: %s" % (mode))
|
||||
if mode != tf.estimator.ModeKeys.PREDICT:
|
||||
raise ValueError("Only PREDICT modes are supported: %s" % (mode))
|
||||
|
||||
tvars = tf.trainable_variables()
|
||||
scaffold_fn = None
|
||||
(assignment_map, _) = modeling.get_assigment_map_from_checkpoint(
|
||||
tvars, init_checkpoint)
|
||||
if use_tpu:
|
||||
tvars = tf.trainable_variables()
|
||||
scaffold_fn = None
|
||||
(assignment_map, _) = modeling.get_assigment_map_from_checkpoint(
|
||||
tvars, init_checkpoint)
|
||||
if use_tpu:
|
||||
|
||||
def tpu_scaffold():
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
return tf.train.Scaffold()
|
||||
def tpu_scaffold():
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
return tf.train.Scaffold()
|
||||
|
||||
scaffold_fn = tpu_scaffold
|
||||
else:
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
scaffold_fn = tpu_scaffold
|
||||
else:
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
|
||||
all_layers = model.get_all_encoder_layers()
|
||||
all_layers = model.get_all_encoder_layers()
|
||||
|
||||
predictions = {
|
||||
"unique_id": unique_ids,
|
||||
}
|
||||
predictions = {
|
||||
"unique_id": unique_ids,
|
||||
}
|
||||
|
||||
for (i, layer_index) in enumerate(layer_indexes):
|
||||
predictions["layer_output_%d" % i] = all_layers[layer_index]
|
||||
for (i, layer_index) in enumerate(layer_indexes):
|
||||
predictions["layer_output_%d" % i] = all_layers[layer_index]
|
||||
|
||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
|
||||
return output_spec
|
||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
|
||||
return output_spec
|
||||
|
||||
return model_fn
|
||||
return model_fn
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, seq_length, tokenizer):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in enumerate(examples):
|
||||
tokens_a = tokenizer.tokenize(example.text_a)
|
||||
features = []
|
||||
for (ex_index, example) in enumerate(examples):
|
||||
tokens_a = tokenizer.tokenize(example.text_a)
|
||||
|
||||
tokens_b = None
|
||||
if example.text_b:
|
||||
tokens_b = tokenizer.tokenize(example.text_b)
|
||||
tokens_b = None
|
||||
if example.text_b:
|
||||
tokens_b = tokenizer.tokenize(example.text_b)
|
||||
|
||||
if tokens_b:
|
||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||
# length is less than the specified length.
|
||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
|
||||
else:
|
||||
# Account for [CLS] and [SEP] with "- 2"
|
||||
if len(tokens_a) > seq_length - 2:
|
||||
tokens_a = tokens_a[0:(seq_length - 2)]
|
||||
if tokens_b:
|
||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||
# length is less than the specified length.
|
||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
|
||||
else:
|
||||
# Account for [CLS] and [SEP] with "- 2"
|
||||
if len(tokens_a) > seq_length - 2:
|
||||
tokens_a = tokens_a[0:(seq_length - 2)]
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||
# (b) For single sequences:
|
||||
# tokens: [CLS] the dog is hairy . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0
|
||||
#
|
||||
# Where "type_ids" are used to indicate whether this is the first
|
||||
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||
# embedding vector (and position vector). This is not *strictly* necessary
|
||||
# since the [SEP] token unambigiously separates the sequences, but it makes
|
||||
# it easier for the model to learn the concept of sequences.
|
||||
#
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = []
|
||||
input_type_ids = []
|
||||
tokens.append("[CLS]")
|
||||
input_type_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
input_type_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
input_type_ids.append(0)
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||
# (b) For single sequences:
|
||||
# tokens: [CLS] the dog is hairy . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0
|
||||
#
|
||||
# Where "type_ids" are used to indicate whether this is the first
|
||||
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||
# embedding vector (and position vector). This is not *strictly* necessary
|
||||
# since the [SEP] token unambigiously separates the sequences, but it makes
|
||||
# it easier for the model to learn the concept of sequences.
|
||||
#
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = []
|
||||
input_type_ids = []
|
||||
tokens.append("[CLS]")
|
||||
input_type_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
input_type_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
input_type_ids.append(0)
|
||||
|
||||
if tokens_b:
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
input_type_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
input_type_ids.append(1)
|
||||
if tokens_b:
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
input_type_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
input_type_ids.append(1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
input_type_ids.append(0)
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
input_type_ids.append(0)
|
||||
|
||||
assert len(input_ids) == seq_length
|
||||
assert len(input_mask) == seq_length
|
||||
assert len(input_type_ids) == seq_length
|
||||
assert len(input_ids) == seq_length
|
||||
assert len(input_mask) == seq_length
|
||||
assert len(input_type_ids) == seq_length
|
||||
|
||||
if ex_index < 5:
|
||||
tf.logging.info("*** Example ***")
|
||||
tf.logging.info("unique_id: %s" % (example.unique_id))
|
||||
tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
|
||||
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
tf.logging.info(
|
||||
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
|
||||
if ex_index < 5:
|
||||
tf.logging.info("*** Example ***")
|
||||
tf.logging.info("unique_id: %s" % (example.unique_id))
|
||||
tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
|
||||
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
tf.logging.info(
|
||||
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
unique_id=example.unique_id,
|
||||
tokens=tokens,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
input_type_ids=input_type_ids))
|
||||
return features
|
||||
features.append(
|
||||
InputFeatures(
|
||||
unique_id=example.unique_id,
|
||||
tokens=tokens,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
input_type_ids=input_type_ids))
|
||||
return features
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def read_examples(input_file):
|
||||
"""Read a list of `InputExample`s from an input file."""
|
||||
examples = []
|
||||
unique_id = 0
|
||||
with tf.gfile.GFile(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
text_a = None
|
||||
text_b = None
|
||||
m = re.match(r"^(.*) \|\|\| (.*)$", line)
|
||||
if m is None:
|
||||
text_a = line
|
||||
else:
|
||||
text_a = m.group(1)
|
||||
text_b = m.group(2)
|
||||
examples.append(
|
||||
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
|
||||
unique_id += 1
|
||||
return examples
|
||||
"""Read a list of `InputExample`s from an input file."""
|
||||
examples = []
|
||||
unique_id = 0
|
||||
with tf.gfile.GFile(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
text_a = None
|
||||
text_b = None
|
||||
m = re.match(r"^(.*) \|\|\| (.*)$", line)
|
||||
if m is None:
|
||||
text_a = line
|
||||
else:
|
||||
text_a = m.group(1)
|
||||
text_b = m.group(2)
|
||||
examples.append(
|
||||
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
|
||||
unique_id += 1
|
||||
return examples
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
|
||||
layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
|
||||
|
||||
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
|
||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
run_config = tf.contrib.tpu.RunConfig(
|
||||
master=FLAGS.master,
|
||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
num_shards=FLAGS.num_tpu_cores,
|
||||
per_host_input_for_training=is_per_host))
|
||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
run_config = tf.contrib.tpu.RunConfig(
|
||||
master=FLAGS.master,
|
||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
num_shards=FLAGS.num_tpu_cores,
|
||||
per_host_input_for_training=is_per_host))
|
||||
|
||||
examples = read_examples(FLAGS.input_file)
|
||||
examples = read_examples(FLAGS.input_file)
|
||||
|
||||
features = convert_examples_to_features(
|
||||
examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
|
||||
features = convert_examples_to_features(
|
||||
examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
|
||||
|
||||
unique_id_to_feature = {}
|
||||
for feature in features:
|
||||
unique_id_to_feature[feature.unique_id] = feature
|
||||
unique_id_to_feature = {}
|
||||
for feature in features:
|
||||
unique_id_to_feature[feature.unique_id] = feature
|
||||
|
||||
model_fn = model_fn_builder(
|
||||
bert_config=bert_config,
|
||||
init_checkpoint=FLAGS.init_checkpoint,
|
||||
layer_indexes=layer_indexes,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
|
||||
model_fn = model_fn_builder(
|
||||
bert_config=bert_config,
|
||||
init_checkpoint=FLAGS.init_checkpoint,
|
||||
layer_indexes=layer_indexes,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
|
||||
|
||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||
# or GPU.
|
||||
estimator = tf.contrib.tpu.TPUEstimator(
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
model_fn=model_fn,
|
||||
config=run_config,
|
||||
predict_batch_size=FLAGS.batch_size)
|
||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||
# or GPU.
|
||||
estimator = tf.contrib.tpu.TPUEstimator(
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
model_fn=model_fn,
|
||||
config=run_config,
|
||||
predict_batch_size=FLAGS.batch_size)
|
||||
|
||||
input_fn = input_fn_builder(
|
||||
features=features, seq_length=FLAGS.max_seq_length)
|
||||
input_fn = input_fn_builder(
|
||||
features=features, seq_length=FLAGS.max_seq_length)
|
||||
|
||||
with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
|
||||
"w")) as writer:
|
||||
for result in estimator.predict(input_fn, yield_single_examples=True):
|
||||
unique_id = int(result["unique_id"])
|
||||
feature = unique_id_to_feature[unique_id]
|
||||
output_json = collections.OrderedDict()
|
||||
output_json["linex_index"] = unique_id
|
||||
all_features = []
|
||||
for (i, token) in enumerate(feature.tokens):
|
||||
all_layers = []
|
||||
for (j, layer_index) in enumerate(layer_indexes):
|
||||
layer_output = result["layer_output_%d" % j]
|
||||
layers = collections.OrderedDict()
|
||||
layers["index"] = layer_index
|
||||
layers["values"] = [
|
||||
round(float(x), 6) for x in layer_output[i:(i + 1)].flat
|
||||
]
|
||||
all_layers.append(layers)
|
||||
features = collections.OrderedDict()
|
||||
features["token"] = token
|
||||
features["layers"] = all_layers
|
||||
all_features.append(features)
|
||||
output_json["features"] = all_features
|
||||
writer.write(json.dumps(output_json) + "\n")
|
||||
with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
|
||||
"w")) as writer:
|
||||
for result in estimator.predict(input_fn, yield_single_examples=True):
|
||||
unique_id = int(result["unique_id"])
|
||||
feature = unique_id_to_feature[unique_id]
|
||||
output_json = collections.OrderedDict()
|
||||
output_json["linex_index"] = unique_id
|
||||
all_features = []
|
||||
for (i, token) in enumerate(feature.tokens):
|
||||
all_layers = []
|
||||
for (j, layer_index) in enumerate(layer_indexes):
|
||||
layer_output = result["layer_output_%d" % j]
|
||||
layers = collections.OrderedDict()
|
||||
layers["index"] = layer_index
|
||||
layers["values"] = [
|
||||
round(float(x), 6) for x in layer_output[i:(i + 1)].flat
|
||||
]
|
||||
all_layers.append(layers)
|
||||
features = collections.OrderedDict()
|
||||
features["token"] = token
|
||||
features["layers"] = all_layers
|
||||
all_features.append(features)
|
||||
output_json["features"] = all_features
|
||||
writer.write(json.dumps(output_json) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
flags.mark_flag_as_required("bert_config_file")
|
||||
flags.mark_flag_as_required("init_checkpoint")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
tf.app.run()
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
flags.mark_flag_as_required("bert_config_file")
|
||||
flags.mark_flag_as_required("init_checkpoint")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
tf.app.run()
|
||||
|
||||
1496
modeling.py
1496
modeling.py
File diff suppressed because it is too large
Load Diff
421
modeling_test.py
421
modeling_test.py
@@ -27,250 +27,249 @@ import tensorflow as tf
|
||||
|
||||
|
||||
class BertModelTest(tf.test.TestCase):
|
||||
class BertModelTester(object):
|
||||
|
||||
class BertModelTester(object):
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
scope=None):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
scope=None):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
def create_model(self):
|
||||
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
|
||||
self.vocab_size)
|
||||
|
||||
def create_model(self):
|
||||
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
|
||||
self.vocab_size)
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = BertModelTest.ids_tensor(
|
||||
[self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = BertModelTest.ids_tensor(
|
||||
[self.batch_size, self.seq_length], vocab_size=2)
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = BertModelTest.ids_tensor(
|
||||
[self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = BertModelTest.ids_tensor(
|
||||
[self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
config = modeling.BertConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
config = modeling.BertConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
model = modeling.BertModel(
|
||||
config=config,
|
||||
is_training=self.is_training,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
scope=self.scope)
|
||||
|
||||
model = modeling.BertModel(
|
||||
config=config,
|
||||
is_training=self.is_training,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
scope=self.scope)
|
||||
outputs = {
|
||||
"embedding_output": model.get_embedding_output(),
|
||||
"sequence_output": model.get_sequence_output(),
|
||||
"pooled_output": model.get_pooled_output(),
|
||||
"all_encoder_layers": model.get_all_encoder_layers(),
|
||||
}
|
||||
return outputs
|
||||
|
||||
outputs = {
|
||||
"embedding_output": model.get_embedding_output(),
|
||||
"sequence_output": model.get_sequence_output(),
|
||||
"pooled_output": model.get_pooled_output(),
|
||||
"all_encoder_layers": model.get_all_encoder_layers(),
|
||||
}
|
||||
return outputs
|
||||
def check_output(self, result):
|
||||
self.parent.assertAllEqual(
|
||||
result["embedding_output"].shape,
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
|
||||
def check_output(self, result):
|
||||
self.parent.assertAllEqual(
|
||||
result["embedding_output"].shape,
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertAllEqual(
|
||||
result["sequence_output"].shape,
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
|
||||
self.parent.assertAllEqual(
|
||||
result["sequence_output"].shape,
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertAllEqual(result["pooled_output"].shape,
|
||||
[self.batch_size, self.hidden_size])
|
||||
|
||||
self.parent.assertAllEqual(result["pooled_output"].shape,
|
||||
[self.batch_size, self.hidden_size])
|
||||
def test_default(self):
|
||||
self.run_tester(BertModelTest.BertModelTester(self))
|
||||
|
||||
def test_default(self):
|
||||
self.run_tester(BertModelTest.BertModelTester(self))
|
||||
def test_config_to_json_string(self):
|
||||
config = modeling.BertConfig(vocab_size=99, hidden_size=37)
|
||||
obj = json.loads(config.to_json_string())
|
||||
self.assertEqual(obj["vocab_size"], 99)
|
||||
self.assertEqual(obj["hidden_size"], 37)
|
||||
|
||||
def test_config_to_json_string(self):
|
||||
config = modeling.BertConfig(vocab_size=99, hidden_size=37)
|
||||
obj = json.loads(config.to_json_string())
|
||||
self.assertEqual(obj["vocab_size"], 99)
|
||||
self.assertEqual(obj["hidden_size"], 37)
|
||||
def run_tester(self, tester):
|
||||
with self.test_session() as sess:
|
||||
ops = tester.create_model()
|
||||
init_op = tf.group(tf.global_variables_initializer(),
|
||||
tf.local_variables_initializer())
|
||||
sess.run(init_op)
|
||||
output_result = sess.run(ops)
|
||||
tester.check_output(output_result)
|
||||
|
||||
def run_tester(self, tester):
|
||||
with self.test_session() as sess:
|
||||
ops = tester.create_model()
|
||||
init_op = tf.group(tf.global_variables_initializer(),
|
||||
tf.local_variables_initializer())
|
||||
sess.run(init_op)
|
||||
output_result = sess.run(ops)
|
||||
tester.check_output(output_result)
|
||||
self.assert_all_tensors_reachable(sess, [init_op, ops])
|
||||
|
||||
self.assert_all_tensors_reachable(sess, [init_op, ops])
|
||||
@classmethod
|
||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
||||
@classmethod
|
||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
|
||||
|
||||
return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
|
||||
def assert_all_tensors_reachable(self, sess, outputs):
|
||||
"""Checks that all the tensors in the graph are reachable from outputs."""
|
||||
graph = sess.graph
|
||||
|
||||
def assert_all_tensors_reachable(self, sess, outputs):
|
||||
"""Checks that all the tensors in the graph are reachable from outputs."""
|
||||
graph = sess.graph
|
||||
ignore_strings = [
|
||||
"^.*/dilation_rate$",
|
||||
"^.*/Tensordot/concat$",
|
||||
"^.*/Tensordot/concat/axis$",
|
||||
"^testing/.*$",
|
||||
]
|
||||
|
||||
ignore_strings = [
|
||||
"^.*/dilation_rate$",
|
||||
"^.*/Tensordot/concat$",
|
||||
"^.*/Tensordot/concat/axis$",
|
||||
"^testing/.*$",
|
||||
]
|
||||
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
||||
|
||||
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
||||
unreachable = self.get_unreachable_ops(graph, outputs)
|
||||
filtered_unreachable = []
|
||||
for x in unreachable:
|
||||
do_ignore = False
|
||||
for r in ignore_regexes:
|
||||
m = r.match(x.name)
|
||||
if m is not None:
|
||||
do_ignore = True
|
||||
if do_ignore:
|
||||
continue
|
||||
filtered_unreachable.append(x)
|
||||
unreachable = filtered_unreachable
|
||||
|
||||
unreachable = self.get_unreachable_ops(graph, outputs)
|
||||
filtered_unreachable = []
|
||||
for x in unreachable:
|
||||
do_ignore = False
|
||||
for r in ignore_regexes:
|
||||
m = r.match(x.name)
|
||||
if m is not None:
|
||||
do_ignore = True
|
||||
if do_ignore:
|
||||
continue
|
||||
filtered_unreachable.append(x)
|
||||
unreachable = filtered_unreachable
|
||||
self.assertEqual(
|
||||
len(unreachable), 0, "The following ops are unreachable: %s" %
|
||||
(" ".join([x.name for x in unreachable])))
|
||||
|
||||
self.assertEqual(
|
||||
len(unreachable), 0, "The following ops are unreachable: %s" %
|
||||
(" ".join([x.name for x in unreachable])))
|
||||
@classmethod
|
||||
def get_unreachable_ops(cls, graph, outputs):
|
||||
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
||||
outputs = cls.flatten_recursive(outputs)
|
||||
output_to_op = collections.defaultdict(list)
|
||||
op_to_all = collections.defaultdict(list)
|
||||
assign_out_to_in = collections.defaultdict(list)
|
||||
|
||||
@classmethod
|
||||
def get_unreachable_ops(cls, graph, outputs):
|
||||
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
||||
outputs = cls.flatten_recursive(outputs)
|
||||
output_to_op = collections.defaultdict(list)
|
||||
op_to_all = collections.defaultdict(list)
|
||||
assign_out_to_in = collections.defaultdict(list)
|
||||
for op in graph.get_operations():
|
||||
for x in op.inputs:
|
||||
op_to_all[op.name].append(x.name)
|
||||
for y in op.outputs:
|
||||
output_to_op[y.name].append(op.name)
|
||||
op_to_all[op.name].append(y.name)
|
||||
if str(op.type) == "Assign":
|
||||
for y in op.outputs:
|
||||
for x in op.inputs:
|
||||
assign_out_to_in[y.name].append(x.name)
|
||||
|
||||
for op in graph.get_operations():
|
||||
for x in op.inputs:
|
||||
op_to_all[op.name].append(x.name)
|
||||
for y in op.outputs:
|
||||
output_to_op[y.name].append(op.name)
|
||||
op_to_all[op.name].append(y.name)
|
||||
if str(op.type) == "Assign":
|
||||
for y in op.outputs:
|
||||
for x in op.inputs:
|
||||
assign_out_to_in[y.name].append(x.name)
|
||||
assign_groups = collections.defaultdict(list)
|
||||
for out_name in assign_out_to_in.keys():
|
||||
name_group = assign_out_to_in[out_name]
|
||||
for n1 in name_group:
|
||||
assign_groups[n1].append(out_name)
|
||||
for n2 in name_group:
|
||||
if n1 != n2:
|
||||
assign_groups[n1].append(n2)
|
||||
|
||||
assign_groups = collections.defaultdict(list)
|
||||
for out_name in assign_out_to_in.keys():
|
||||
name_group = assign_out_to_in[out_name]
|
||||
for n1 in name_group:
|
||||
assign_groups[n1].append(out_name)
|
||||
for n2 in name_group:
|
||||
if n1 != n2:
|
||||
assign_groups[n1].append(n2)
|
||||
seen_tensors = {}
|
||||
stack = [x.name for x in outputs]
|
||||
while stack:
|
||||
name = stack.pop()
|
||||
if name in seen_tensors:
|
||||
continue
|
||||
seen_tensors[name] = True
|
||||
|
||||
seen_tensors = {}
|
||||
stack = [x.name for x in outputs]
|
||||
while stack:
|
||||
name = stack.pop()
|
||||
if name in seen_tensors:
|
||||
continue
|
||||
seen_tensors[name] = True
|
||||
if name in output_to_op:
|
||||
for op_name in output_to_op[name]:
|
||||
if op_name in op_to_all:
|
||||
for input_name in op_to_all[op_name]:
|
||||
if input_name not in stack:
|
||||
stack.append(input_name)
|
||||
|
||||
if name in output_to_op:
|
||||
for op_name in output_to_op[name]:
|
||||
if op_name in op_to_all:
|
||||
for input_name in op_to_all[op_name]:
|
||||
if input_name not in stack:
|
||||
stack.append(input_name)
|
||||
expanded_names = []
|
||||
if name in assign_groups:
|
||||
for assign_name in assign_groups[name]:
|
||||
expanded_names.append(assign_name)
|
||||
|
||||
expanded_names = []
|
||||
if name in assign_groups:
|
||||
for assign_name in assign_groups[name]:
|
||||
expanded_names.append(assign_name)
|
||||
for expanded_name in expanded_names:
|
||||
if expanded_name not in stack:
|
||||
stack.append(expanded_name)
|
||||
|
||||
for expanded_name in expanded_names:
|
||||
if expanded_name not in stack:
|
||||
stack.append(expanded_name)
|
||||
unreachable_ops = []
|
||||
for op in graph.get_operations():
|
||||
is_unreachable = False
|
||||
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
||||
for name in all_names:
|
||||
if name not in seen_tensors:
|
||||
is_unreachable = True
|
||||
if is_unreachable:
|
||||
unreachable_ops.append(op)
|
||||
return unreachable_ops
|
||||
|
||||
unreachable_ops = []
|
||||
for op in graph.get_operations():
|
||||
is_unreachable = False
|
||||
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
||||
for name in all_names:
|
||||
if name not in seen_tensors:
|
||||
is_unreachable = True
|
||||
if is_unreachable:
|
||||
unreachable_ops.append(op)
|
||||
return unreachable_ops
|
||||
@classmethod
|
||||
def flatten_recursive(cls, item):
|
||||
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
||||
output = []
|
||||
if isinstance(item, list):
|
||||
output.extend(item)
|
||||
elif isinstance(item, tuple):
|
||||
output.extend(list(item))
|
||||
elif isinstance(item, dict):
|
||||
for (_, v) in six.iteritems(item):
|
||||
output.append(v)
|
||||
else:
|
||||
return [item]
|
||||
|
||||
@classmethod
|
||||
def flatten_recursive(cls, item):
|
||||
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
||||
output = []
|
||||
if isinstance(item, list):
|
||||
output.extend(item)
|
||||
elif isinstance(item, tuple):
|
||||
output.extend(list(item))
|
||||
elif isinstance(item, dict):
|
||||
for (_, v) in six.iteritems(item):
|
||||
output.append(v)
|
||||
else:
|
||||
return [item]
|
||||
|
||||
flat_output = []
|
||||
for x in output:
|
||||
flat_output.extend(cls.flatten_recursive(x))
|
||||
return flat_output
|
||||
flat_output = []
|
||||
for x in output:
|
||||
flat_output.extend(cls.flatten_recursive(x))
|
||||
return flat_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
tf.test.main()
|
||||
|
||||
236
optimization.py
236
optimization.py
@@ -23,149 +23,149 @@ import tensorflow as tf
|
||||
|
||||
|
||||
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
|
||||
"""Creates an optimizer training op."""
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
"""Creates an optimizer training op."""
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
|
||||
learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
|
||||
learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
|
||||
|
||||
# Implements linear decay of the learning rate.
|
||||
learning_rate = tf.train.polynomial_decay(
|
||||
learning_rate,
|
||||
global_step,
|
||||
num_train_steps,
|
||||
end_learning_rate=0.0,
|
||||
power=1.0,
|
||||
cycle=False)
|
||||
# Implements linear decay of the learning rate.
|
||||
learning_rate = tf.train.polynomial_decay(
|
||||
learning_rate,
|
||||
global_step,
|
||||
num_train_steps,
|
||||
end_learning_rate=0.0,
|
||||
power=1.0,
|
||||
cycle=False)
|
||||
|
||||
# Implements linear warmup. I.e., if global_step < num_warmup_steps, the
|
||||
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||
if num_warmup_steps:
|
||||
global_steps_int = tf.cast(global_step, tf.int32)
|
||||
warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
|
||||
# Implements linear warmup. I.e., if global_step < num_warmup_steps, the
|
||||
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||
if num_warmup_steps:
|
||||
global_steps_int = tf.cast(global_step, tf.int32)
|
||||
warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
|
||||
|
||||
global_steps_float = tf.cast(global_steps_int, tf.float32)
|
||||
warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
|
||||
global_steps_float = tf.cast(global_steps_int, tf.float32)
|
||||
warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
|
||||
|
||||
warmup_percent_done = global_steps_float / warmup_steps_float
|
||||
warmup_learning_rate = init_lr * warmup_percent_done
|
||||
warmup_percent_done = global_steps_float / warmup_steps_float
|
||||
warmup_learning_rate = init_lr * warmup_percent_done
|
||||
|
||||
is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
|
||||
learning_rate = (
|
||||
(1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
|
||||
is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
|
||||
learning_rate = (
|
||||
(1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
|
||||
|
||||
# It is recommended that you use this optimizer for fine tuning, since this
|
||||
# is how the model was trained (note that the Adam m/v variables are NOT
|
||||
# loaded from init_checkpoint.)
|
||||
optimizer = AdamWeightDecayOptimizer(
|
||||
learning_rate=learning_rate,
|
||||
weight_decay_rate=0.01,
|
||||
beta_1=0.9,
|
||||
beta_2=0.999,
|
||||
epsilon=1e-6,
|
||||
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
|
||||
# It is recommended that you use this optimizer for fine tuning, since this
|
||||
# is how the model was trained (note that the Adam m/v variables are NOT
|
||||
# loaded from init_checkpoint.)
|
||||
optimizer = AdamWeightDecayOptimizer(
|
||||
learning_rate=learning_rate,
|
||||
weight_decay_rate=0.01,
|
||||
beta_1=0.9,
|
||||
beta_2=0.999,
|
||||
epsilon=1e-6,
|
||||
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
|
||||
|
||||
if use_tpu:
|
||||
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
|
||||
if use_tpu:
|
||||
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
|
||||
|
||||
tvars = tf.trainable_variables()
|
||||
grads = tf.gradients(loss, tvars)
|
||||
tvars = tf.trainable_variables()
|
||||
grads = tf.gradients(loss, tvars)
|
||||
|
||||
# This is how the model was pre-trained.
|
||||
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
|
||||
# This is how the model was pre-trained.
|
||||
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
|
||||
|
||||
train_op = optimizer.apply_gradients(
|
||||
zip(grads, tvars), global_step=global_step)
|
||||
train_op = optimizer.apply_gradients(
|
||||
zip(grads, tvars), global_step=global_step)
|
||||
|
||||
new_global_step = global_step + 1
|
||||
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
|
||||
return train_op
|
||||
new_global_step = global_step + 1
|
||||
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
|
||||
return train_op
|
||||
|
||||
|
||||
class AdamWeightDecayOptimizer(tf.train.Optimizer):
|
||||
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
||||
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
weight_decay_rate=0.0,
|
||||
beta_1=0.9,
|
||||
beta_2=0.999,
|
||||
epsilon=1e-6,
|
||||
exclude_from_weight_decay=None,
|
||||
name="AdamWeightDecayOptimizer"):
|
||||
"""Constructs a AdamWeightDecayOptimizer."""
|
||||
super(AdamWeightDecayOptimizer, self).__init__(False, name)
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
weight_decay_rate=0.0,
|
||||
beta_1=0.9,
|
||||
beta_2=0.999,
|
||||
epsilon=1e-6,
|
||||
exclude_from_weight_decay=None,
|
||||
name="AdamWeightDecayOptimizer"):
|
||||
"""Constructs a AdamWeightDecayOptimizer."""
|
||||
super(AdamWeightDecayOptimizer, self).__init__(False, name)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay_rate = weight_decay_rate
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
self.epsilon = epsilon
|
||||
self.exclude_from_weight_decay = exclude_from_weight_decay
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay_rate = weight_decay_rate
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
self.epsilon = epsilon
|
||||
self.exclude_from_weight_decay = exclude_from_weight_decay
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
"""See base class."""
|
||||
assignments = []
|
||||
for (grad, param) in grads_and_vars:
|
||||
if grad is None or param is None:
|
||||
continue
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
"""See base class."""
|
||||
assignments = []
|
||||
for (grad, param) in grads_and_vars:
|
||||
if grad is None or param is None:
|
||||
continue
|
||||
|
||||
param_name = self._get_variable_name(param.name)
|
||||
param_name = self._get_variable_name(param.name)
|
||||
|
||||
m = tf.get_variable(
|
||||
name=param_name + "/adam_m",
|
||||
shape=param.shape.as_list(),
|
||||
dtype=tf.float32,
|
||||
trainable=False,
|
||||
initializer=tf.zeros_initializer())
|
||||
v = tf.get_variable(
|
||||
name=param_name + "/adam_v",
|
||||
shape=param.shape.as_list(),
|
||||
dtype=tf.float32,
|
||||
trainable=False,
|
||||
initializer=tf.zeros_initializer())
|
||||
m = tf.get_variable(
|
||||
name=param_name + "/adam_m",
|
||||
shape=param.shape.as_list(),
|
||||
dtype=tf.float32,
|
||||
trainable=False,
|
||||
initializer=tf.zeros_initializer())
|
||||
v = tf.get_variable(
|
||||
name=param_name + "/adam_v",
|
||||
shape=param.shape.as_list(),
|
||||
dtype=tf.float32,
|
||||
trainable=False,
|
||||
initializer=tf.zeros_initializer())
|
||||
|
||||
# Standard Adam update.
|
||||
next_m = (
|
||||
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
|
||||
next_v = (
|
||||
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
|
||||
tf.square(grad)))
|
||||
# Standard Adam update.
|
||||
next_m = (
|
||||
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
|
||||
next_v = (
|
||||
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
|
||||
tf.square(grad)))
|
||||
|
||||
update = next_m / (tf.sqrt(next_v) + self.epsilon)
|
||||
update = next_m / (tf.sqrt(next_v) + self.epsilon)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want ot decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
if self._do_use_weight_decay(param_name):
|
||||
update += self.weight_decay_rate * param
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want ot decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
if self._do_use_weight_decay(param_name):
|
||||
update += self.weight_decay_rate * param
|
||||
|
||||
update_with_lr = self.learning_rate * update
|
||||
update_with_lr = self.learning_rate * update
|
||||
|
||||
next_param = param - update_with_lr
|
||||
next_param = param - update_with_lr
|
||||
|
||||
assignments.extend(
|
||||
[param.assign(next_param),
|
||||
m.assign(next_m),
|
||||
v.assign(next_v)])
|
||||
return tf.group(*assignments, name=name)
|
||||
assignments.extend(
|
||||
[param.assign(next_param),
|
||||
m.assign(next_m),
|
||||
v.assign(next_v)])
|
||||
return tf.group(*assignments, name=name)
|
||||
|
||||
def _do_use_weight_decay(self, param_name):
|
||||
"""Whether to use L2 weight decay for `param_name`."""
|
||||
if not self.weight_decay_rate:
|
||||
return False
|
||||
if self.exclude_from_weight_decay:
|
||||
for r in self.exclude_from_weight_decay:
|
||||
if re.search(r, param_name) is not None:
|
||||
return False
|
||||
return True
|
||||
def _do_use_weight_decay(self, param_name):
|
||||
"""Whether to use L2 weight decay for `param_name`."""
|
||||
if not self.weight_decay_rate:
|
||||
return False
|
||||
if self.exclude_from_weight_decay:
|
||||
for r in self.exclude_from_weight_decay:
|
||||
if re.search(r, param_name) is not None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_variable_name(self, param_name):
|
||||
"""Get the variable name from the tensor name."""
|
||||
m = re.match("^(.*):\\d+$", param_name)
|
||||
if m is not None:
|
||||
param_name = m.group(1)
|
||||
return param_name
|
||||
def _get_variable_name(self, param_name):
|
||||
"""Get the variable name from the tensor name."""
|
||||
m = re.match("^(.*):\\d+$", param_name)
|
||||
if m is not None:
|
||||
param_name = m.group(1)
|
||||
return param_name
|
||||
|
||||
@@ -22,27 +22,27 @@ import tensorflow as tf
|
||||
|
||||
class OptimizationTest(tf.test.TestCase):
|
||||
|
||||
def test_adam(self):
|
||||
with self.test_session() as sess:
|
||||
w = tf.get_variable(
|
||||
"w",
|
||||
shape=[3],
|
||||
initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
|
||||
x = tf.constant([0.4, 0.2, -0.5])
|
||||
loss = tf.reduce_mean(tf.square(x - w))
|
||||
tvars = tf.trainable_variables()
|
||||
grads = tf.gradients(loss, tvars)
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
|
||||
train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
|
||||
init_op = tf.group(tf.global_variables_initializer(),
|
||||
tf.local_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for _ in range(100):
|
||||
sess.run(train_op)
|
||||
w_np = sess.run(w)
|
||||
self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
|
||||
def test_adam(self):
|
||||
with self.test_session() as sess:
|
||||
w = tf.get_variable(
|
||||
"w",
|
||||
shape=[3],
|
||||
initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
|
||||
x = tf.constant([0.4, 0.2, -0.5])
|
||||
loss = tf.reduce_mean(tf.square(x - w))
|
||||
tvars = tf.trainable_variables()
|
||||
grads = tf.gradients(loss, tvars)
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
|
||||
train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
|
||||
init_op = tf.group(tf.global_variables_initializer(),
|
||||
tf.local_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for _ in range(100):
|
||||
sess.run(train_op)
|
||||
w_np = sess.run(w)
|
||||
self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
tf.test.main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -109,217 +109,217 @@ flags.DEFINE_integer(
|
||||
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
||||
num_train_steps, num_warmup_steps, use_tpu,
|
||||
use_one_hot_embeddings):
|
||||
"""Returns `model_fn` closure for TPUEstimator."""
|
||||
"""Returns `model_fn` closure for TPUEstimator."""
|
||||
|
||||
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||
"""The `model_fn` for TPUEstimator."""
|
||||
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||
"""The `model_fn` for TPUEstimator."""
|
||||
|
||||
tf.logging.info("*** Features ***")
|
||||
for name in sorted(features.keys()):
|
||||
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
||||
tf.logging.info("*** Features ***")
|
||||
for name in sorted(features.keys()):
|
||||
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
||||
|
||||
input_ids = features["input_ids"]
|
||||
input_mask = features["input_mask"]
|
||||
segment_ids = features["segment_ids"]
|
||||
masked_lm_positions = features["masked_lm_positions"]
|
||||
masked_lm_ids = features["masked_lm_ids"]
|
||||
masked_lm_weights = features["masked_lm_weights"]
|
||||
next_sentence_labels = features["next_sentence_labels"]
|
||||
input_ids = features["input_ids"]
|
||||
input_mask = features["input_mask"]
|
||||
segment_ids = features["segment_ids"]
|
||||
masked_lm_positions = features["masked_lm_positions"]
|
||||
masked_lm_ids = features["masked_lm_ids"]
|
||||
masked_lm_weights = features["masked_lm_weights"]
|
||||
next_sentence_labels = features["next_sentence_labels"]
|
||||
|
||||
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
||||
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
||||
|
||||
model = modeling.BertModel(
|
||||
config=bert_config,
|
||||
is_training=is_training,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
token_type_ids=segment_ids,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
model = modeling.BertModel(
|
||||
config=bert_config,
|
||||
is_training=is_training,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
token_type_ids=segment_ids,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
(masked_lm_loss,
|
||||
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
||||
bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
||||
masked_lm_positions, masked_lm_ids, masked_lm_weights)
|
||||
(masked_lm_loss,
|
||||
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
||||
bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
||||
masked_lm_positions, masked_lm_ids, masked_lm_weights)
|
||||
|
||||
(next_sentence_loss, next_sentence_example_loss,
|
||||
next_sentence_log_probs) = get_next_sentence_output(
|
||||
bert_config, model.get_pooled_output(), next_sentence_labels)
|
||||
(next_sentence_loss, next_sentence_example_loss,
|
||||
next_sentence_log_probs) = get_next_sentence_output(
|
||||
bert_config, model.get_pooled_output(), next_sentence_labels)
|
||||
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
|
||||
tvars = tf.trainable_variables()
|
||||
tvars = tf.trainable_variables()
|
||||
|
||||
initialized_variable_names = {}
|
||||
scaffold_fn = None
|
||||
if init_checkpoint:
|
||||
(assignment_map,
|
||||
initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
|
||||
tvars, init_checkpoint)
|
||||
if use_tpu:
|
||||
initialized_variable_names = {}
|
||||
scaffold_fn = None
|
||||
if init_checkpoint:
|
||||
(assignment_map,
|
||||
initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
|
||||
tvars, init_checkpoint)
|
||||
if use_tpu:
|
||||
|
||||
def tpu_scaffold():
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
return tf.train.Scaffold()
|
||||
def tpu_scaffold():
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
return tf.train.Scaffold()
|
||||
|
||||
scaffold_fn = tpu_scaffold
|
||||
else:
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
scaffold_fn = tpu_scaffold
|
||||
else:
|
||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||
|
||||
tf.logging.info("**** Trainable Variables ****")
|
||||
for var in tvars:
|
||||
init_string = ""
|
||||
if var.name in initialized_variable_names:
|
||||
init_string = ", *INIT_FROM_CKPT*"
|
||||
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
||||
init_string)
|
||||
tf.logging.info("**** Trainable Variables ****")
|
||||
for var in tvars:
|
||||
init_string = ""
|
||||
if var.name in initialized_variable_names:
|
||||
init_string = ", *INIT_FROM_CKPT*"
|
||||
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
||||
init_string)
|
||||
|
||||
output_spec = None
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
train_op = optimization.create_optimizer(
|
||||
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
||||
output_spec = None
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
train_op = optimization.create_optimizer(
|
||||
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
||||
|
||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode,
|
||||
loss=total_loss,
|
||||
train_op=train_op,
|
||||
scaffold_fn=scaffold_fn)
|
||||
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode,
|
||||
loss=total_loss,
|
||||
train_op=train_op,
|
||||
scaffold_fn=scaffold_fn)
|
||||
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||
|
||||
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_example_loss,
|
||||
next_sentence_log_probs, next_sentence_labels):
|
||||
"""Computes the loss and accuracy of the model."""
|
||||
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
|
||||
[-1, masked_lm_log_probs.shape[-1]])
|
||||
masked_lm_predictions = tf.argmax(
|
||||
masked_lm_log_probs, axis=-1, output_type=tf.int32)
|
||||
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
|
||||
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
|
||||
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
|
||||
masked_lm_accuracy = tf.metrics.accuracy(
|
||||
labels=masked_lm_ids,
|
||||
predictions=masked_lm_predictions,
|
||||
weights=masked_lm_weights)
|
||||
masked_lm_mean_loss = tf.metrics.mean(
|
||||
values=masked_lm_example_loss, weights=masked_lm_weights)
|
||||
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_example_loss,
|
||||
next_sentence_log_probs, next_sentence_labels):
|
||||
"""Computes the loss and accuracy of the model."""
|
||||
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
|
||||
[-1, masked_lm_log_probs.shape[-1]])
|
||||
masked_lm_predictions = tf.argmax(
|
||||
masked_lm_log_probs, axis=-1, output_type=tf.int32)
|
||||
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
|
||||
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
|
||||
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
|
||||
masked_lm_accuracy = tf.metrics.accuracy(
|
||||
labels=masked_lm_ids,
|
||||
predictions=masked_lm_predictions,
|
||||
weights=masked_lm_weights)
|
||||
masked_lm_mean_loss = tf.metrics.mean(
|
||||
values=masked_lm_example_loss, weights=masked_lm_weights)
|
||||
|
||||
next_sentence_log_probs = tf.reshape(
|
||||
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
|
||||
next_sentence_predictions = tf.argmax(
|
||||
next_sentence_log_probs, axis=-1, output_type=tf.int32)
|
||||
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
|
||||
next_sentence_accuracy = tf.metrics.accuracy(
|
||||
labels=next_sentence_labels, predictions=next_sentence_predictions)
|
||||
next_sentence_mean_loss = tf.metrics.mean(
|
||||
values=next_sentence_example_loss)
|
||||
next_sentence_log_probs = tf.reshape(
|
||||
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
|
||||
next_sentence_predictions = tf.argmax(
|
||||
next_sentence_log_probs, axis=-1, output_type=tf.int32)
|
||||
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
|
||||
next_sentence_accuracy = tf.metrics.accuracy(
|
||||
labels=next_sentence_labels, predictions=next_sentence_predictions)
|
||||
next_sentence_mean_loss = tf.metrics.mean(
|
||||
values=next_sentence_example_loss)
|
||||
|
||||
return {
|
||||
"masked_lm_accuracy": masked_lm_accuracy,
|
||||
"masked_lm_loss": masked_lm_mean_loss,
|
||||
"next_sentence_accuracy": next_sentence_accuracy,
|
||||
"next_sentence_loss": next_sentence_mean_loss,
|
||||
}
|
||||
return {
|
||||
"masked_lm_accuracy": masked_lm_accuracy,
|
||||
"masked_lm_loss": masked_lm_mean_loss,
|
||||
"next_sentence_accuracy": next_sentence_accuracy,
|
||||
"next_sentence_loss": next_sentence_mean_loss,
|
||||
}
|
||||
|
||||
eval_metrics = (metric_fn, [
|
||||
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_example_loss,
|
||||
next_sentence_log_probs, next_sentence_labels
|
||||
])
|
||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode,
|
||||
loss=total_loss,
|
||||
eval_metrics=eval_metrics,
|
||||
scaffold_fn=scaffold_fn)
|
||||
else:
|
||||
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
|
||||
eval_metrics = (metric_fn, [
|
||||
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_example_loss,
|
||||
next_sentence_log_probs, next_sentence_labels
|
||||
])
|
||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode,
|
||||
loss=total_loss,
|
||||
eval_metrics=eval_metrics,
|
||||
scaffold_fn=scaffold_fn)
|
||||
else:
|
||||
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
|
||||
|
||||
return output_spec
|
||||
return output_spec
|
||||
|
||||
return model_fn
|
||||
return model_fn
|
||||
|
||||
|
||||
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
|
||||
label_ids, label_weights):
|
||||
"""Get loss and log probs for the masked LM."""
|
||||
input_tensor = gather_indexes(input_tensor, positions)
|
||||
"""Get loss and log probs for the masked LM."""
|
||||
input_tensor = gather_indexes(input_tensor, positions)
|
||||
|
||||
with tf.variable_scope("cls/predictions"):
|
||||
# We apply one more non-linear transformation before the output layer.
|
||||
# This matrix is not used after pre-training.
|
||||
with tf.variable_scope("transform"):
|
||||
input_tensor = tf.layers.dense(
|
||||
input_tensor,
|
||||
units=bert_config.hidden_size,
|
||||
activation=modeling.get_activation(bert_config.hidden_act),
|
||||
kernel_initializer=modeling.create_initializer(
|
||||
bert_config.initializer_range))
|
||||
input_tensor = modeling.layer_norm(input_tensor)
|
||||
with tf.variable_scope("cls/predictions"):
|
||||
# We apply one more non-linear transformation before the output layer.
|
||||
# This matrix is not used after pre-training.
|
||||
with tf.variable_scope("transform"):
|
||||
input_tensor = tf.layers.dense(
|
||||
input_tensor,
|
||||
units=bert_config.hidden_size,
|
||||
activation=modeling.get_activation(bert_config.hidden_act),
|
||||
kernel_initializer=modeling.create_initializer(
|
||||
bert_config.initializer_range))
|
||||
input_tensor = modeling.layer_norm(input_tensor)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
output_bias = tf.get_variable(
|
||||
"output_bias",
|
||||
shape=[bert_config.vocab_size],
|
||||
initializer=tf.zeros_initializer())
|
||||
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||
logits = tf.nn.bias_add(logits, output_bias)
|
||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
output_bias = tf.get_variable(
|
||||
"output_bias",
|
||||
shape=[bert_config.vocab_size],
|
||||
initializer=tf.zeros_initializer())
|
||||
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||
logits = tf.nn.bias_add(logits, output_bias)
|
||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||
|
||||
label_ids = tf.reshape(label_ids, [-1])
|
||||
label_weights = tf.reshape(label_weights, [-1])
|
||||
label_ids = tf.reshape(label_ids, [-1])
|
||||
label_weights = tf.reshape(label_weights, [-1])
|
||||
|
||||
one_hot_labels = tf.one_hot(
|
||||
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
|
||||
one_hot_labels = tf.one_hot(
|
||||
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
|
||||
|
||||
# The `positions` tensor might be zero-padded (if the sequence is too
|
||||
# short to have the maximum number of predictions). The `label_weights`
|
||||
# tensor has a value of 1.0 for every real prediction and 0.0 for the
|
||||
# padding predictions.
|
||||
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
|
||||
numerator = tf.reduce_sum(label_weights * per_example_loss)
|
||||
denominator = tf.reduce_sum(label_weights) + 1e-5
|
||||
loss = numerator / denominator
|
||||
# The `positions` tensor might be zero-padded (if the sequence is too
|
||||
# short to have the maximum number of predictions). The `label_weights`
|
||||
# tensor has a value of 1.0 for every real prediction and 0.0 for the
|
||||
# padding predictions.
|
||||
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
|
||||
numerator = tf.reduce_sum(label_weights * per_example_loss)
|
||||
denominator = tf.reduce_sum(label_weights) + 1e-5
|
||||
loss = numerator / denominator
|
||||
|
||||
return (loss, per_example_loss, log_probs)
|
||||
|
||||
|
||||
def get_next_sentence_output(bert_config, input_tensor, labels):
|
||||
"""Get loss and log probs for the next sentence prediction."""
|
||||
|
||||
# Simple binary classification. Note that 0 is "next sentence" and 1 is
|
||||
# "random sentence". This weight matrix is not used after pre-training.
|
||||
with tf.variable_scope("cls/seq_relationship"):
|
||||
output_weights = tf.get_variable(
|
||||
"output_weights",
|
||||
shape=[2, bert_config.hidden_size],
|
||||
initializer=modeling.create_initializer(bert_config.initializer_range))
|
||||
output_bias = tf.get_variable(
|
||||
"output_bias", shape=[2], initializer=tf.zeros_initializer())
|
||||
|
||||
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||
logits = tf.nn.bias_add(logits, output_bias)
|
||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||
labels = tf.reshape(labels, [-1])
|
||||
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
|
||||
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
||||
loss = tf.reduce_mean(per_example_loss)
|
||||
return (loss, per_example_loss, log_probs)
|
||||
|
||||
|
||||
def gather_indexes(sequence_tensor, positions):
|
||||
"""Gathers the vectors at the specific positions over a minibatch."""
|
||||
sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
|
||||
batch_size = sequence_shape[0]
|
||||
seq_length = sequence_shape[1]
|
||||
width = sequence_shape[2]
|
||||
def get_next_sentence_output(bert_config, input_tensor, labels):
|
||||
"""Get loss and log probs for the next sentence prediction."""
|
||||
|
||||
flat_offsets = tf.reshape(
|
||||
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
|
||||
flat_positions = tf.reshape(positions + flat_offsets, [-1])
|
||||
flat_sequence_tensor = tf.reshape(sequence_tensor,
|
||||
[batch_size * seq_length, width])
|
||||
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
|
||||
return output_tensor
|
||||
# Simple binary classification. Note that 0 is "next sentence" and 1 is
|
||||
# "random sentence". This weight matrix is not used after pre-training.
|
||||
with tf.variable_scope("cls/seq_relationship"):
|
||||
output_weights = tf.get_variable(
|
||||
"output_weights",
|
||||
shape=[2, bert_config.hidden_size],
|
||||
initializer=modeling.create_initializer(bert_config.initializer_range))
|
||||
output_bias = tf.get_variable(
|
||||
"output_bias", shape=[2], initializer=tf.zeros_initializer())
|
||||
|
||||
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||
logits = tf.nn.bias_add(logits, output_bias)
|
||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||
labels = tf.reshape(labels, [-1])
|
||||
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
|
||||
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
||||
loss = tf.reduce_mean(per_example_loss)
|
||||
return (loss, per_example_loss, log_probs)
|
||||
|
||||
|
||||
def gather_indexes(sequence_tensor, positions):
|
||||
"""Gathers the vectors at the specific positions over a minibatch."""
|
||||
sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
|
||||
batch_size = sequence_shape[0]
|
||||
seq_length = sequence_shape[1]
|
||||
width = sequence_shape[2]
|
||||
|
||||
flat_offsets = tf.reshape(
|
||||
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
|
||||
flat_positions = tf.reshape(positions + flat_offsets, [-1])
|
||||
flat_sequence_tensor = tf.reshape(sequence_tensor,
|
||||
[batch_size * seq_length, width])
|
||||
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def input_fn_builder(input_files,
|
||||
@@ -327,168 +327,168 @@ def input_fn_builder(input_files,
|
||||
max_predictions_per_seq,
|
||||
is_training,
|
||||
num_cpu_threads=4):
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||
|
||||
def input_fn(params):
|
||||
"""The actual input function."""
|
||||
batch_size = params["batch_size"]
|
||||
def input_fn(params):
|
||||
"""The actual input function."""
|
||||
batch_size = params["batch_size"]
|
||||
|
||||
name_to_features = {
|
||||
"input_ids":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"input_mask":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"segment_ids":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"masked_lm_positions":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
"masked_lm_ids":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
"masked_lm_weights":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||
"next_sentence_labels":
|
||||
tf.FixedLenFeature([1], tf.int64),
|
||||
}
|
||||
name_to_features = {
|
||||
"input_ids":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"input_mask":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"segment_ids":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"masked_lm_positions":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
"masked_lm_ids":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
"masked_lm_weights":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||
"next_sentence_labels":
|
||||
tf.FixedLenFeature([1], tf.int64),
|
||||
}
|
||||
|
||||
# For training, we want a lot of parallel reading and shuffling.
|
||||
# For eval, we want no shuffling and parallel reading doesn't matter.
|
||||
if is_training:
|
||||
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
|
||||
d = d.repeat()
|
||||
d = d.shuffle(buffer_size=len(input_files))
|
||||
# For training, we want a lot of parallel reading and shuffling.
|
||||
# For eval, we want no shuffling and parallel reading doesn't matter.
|
||||
if is_training:
|
||||
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
|
||||
d = d.repeat()
|
||||
d = d.shuffle(buffer_size=len(input_files))
|
||||
|
||||
# `cycle_length` is the number of parallel files that get read.
|
||||
cycle_length = min(num_cpu_threads, len(input_files))
|
||||
# `cycle_length` is the number of parallel files that get read.
|
||||
cycle_length = min(num_cpu_threads, len(input_files))
|
||||
|
||||
# `sloppy` mode means that the interleaving is not exact. This adds
|
||||
# even more randomness to the training pipeline.
|
||||
d = d.apply(
|
||||
tf.contrib.data.parallel_interleave(
|
||||
tf.data.TFRecordDataset,
|
||||
sloppy=is_training,
|
||||
cycle_length=cycle_length))
|
||||
d = d.shuffle(buffer_size=100)
|
||||
else:
|
||||
d = tf.data.TFRecordDataset(input_files)
|
||||
# Since we evaluate for a fixed number of steps we don't want to encounter
|
||||
# out-of-range exceptions.
|
||||
d = d.repeat()
|
||||
# `sloppy` mode means that the interleaving is not exact. This adds
|
||||
# even more randomness to the training pipeline.
|
||||
d = d.apply(
|
||||
tf.contrib.data.parallel_interleave(
|
||||
tf.data.TFRecordDataset,
|
||||
sloppy=is_training,
|
||||
cycle_length=cycle_length))
|
||||
d = d.shuffle(buffer_size=100)
|
||||
else:
|
||||
d = tf.data.TFRecordDataset(input_files)
|
||||
# Since we evaluate for a fixed number of steps we don't want to encounter
|
||||
# out-of-range exceptions.
|
||||
d = d.repeat()
|
||||
|
||||
# We must `drop_remainder` on training because the TPU requires fixed
|
||||
# size dimensions. For eval, we assume we are evaling on the CPU or GPU
|
||||
# and we *don"t* want to drop the remainder, otherwise we wont cover
|
||||
# every sample.
|
||||
d = d.apply(
|
||||
tf.contrib.data.map_and_batch(
|
||||
lambda record: _decode_record(record, name_to_features),
|
||||
batch_size=batch_size,
|
||||
num_parallel_batches=num_cpu_threads,
|
||||
drop_remainder=True))
|
||||
return d
|
||||
# We must `drop_remainder` on training because the TPU requires fixed
|
||||
# size dimensions. For eval, we assume we are evaling on the CPU or GPU
|
||||
# and we *don"t* want to drop the remainder, otherwise we wont cover
|
||||
# every sample.
|
||||
d = d.apply(
|
||||
tf.contrib.data.map_and_batch(
|
||||
lambda record: _decode_record(record, name_to_features),
|
||||
batch_size=batch_size,
|
||||
num_parallel_batches=num_cpu_threads,
|
||||
drop_remainder=True))
|
||||
return d
|
||||
|
||||
return input_fn
|
||||
return input_fn
|
||||
|
||||
|
||||
def _decode_record(record, name_to_features):
|
||||
"""Decodes a record to a TensorFlow example."""
|
||||
example = tf.parse_single_example(record, name_to_features)
|
||||
"""Decodes a record to a TensorFlow example."""
|
||||
example = tf.parse_single_example(record, name_to_features)
|
||||
|
||||
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||
# So cast all int64 to int32.
|
||||
for name in list(example.keys()):
|
||||
t = example[name]
|
||||
if t.dtype == tf.int64:
|
||||
t = tf.to_int32(t)
|
||||
example[name] = t
|
||||
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||
# So cast all int64 to int32.
|
||||
for name in list(example.keys()):
|
||||
t = example[name]
|
||||
if t.dtype == tf.int64:
|
||||
t = tf.to_int32(t)
|
||||
example[name] = t
|
||||
|
||||
return example
|
||||
return example
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
if not FLAGS.do_train and not FLAGS.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
if not FLAGS.do_train and not FLAGS.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
|
||||
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
|
||||
tf.gfile.MakeDirs(FLAGS.output_dir)
|
||||
tf.gfile.MakeDirs(FLAGS.output_dir)
|
||||
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
|
||||
tf.logging.info("*** Input Files ***")
|
||||
for input_file in input_files:
|
||||
tf.logging.info(" %s" % input_file)
|
||||
tf.logging.info("*** Input Files ***")
|
||||
for input_file in input_files:
|
||||
tf.logging.info(" %s" % input_file)
|
||||
|
||||
tpu_cluster_resolver = None
|
||||
if FLAGS.use_tpu and FLAGS.tpu_name:
|
||||
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
||||
tpu_cluster_resolver = None
|
||||
if FLAGS.use_tpu and FLAGS.tpu_name:
|
||||
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
||||
|
||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
run_config = tf.contrib.tpu.RunConfig(
|
||||
cluster=tpu_cluster_resolver,
|
||||
master=FLAGS.master,
|
||||
model_dir=FLAGS.output_dir,
|
||||
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
iterations_per_loop=FLAGS.iterations_per_loop,
|
||||
num_shards=FLAGS.num_tpu_cores,
|
||||
per_host_input_for_training=is_per_host))
|
||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
run_config = tf.contrib.tpu.RunConfig(
|
||||
cluster=tpu_cluster_resolver,
|
||||
master=FLAGS.master,
|
||||
model_dir=FLAGS.output_dir,
|
||||
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
iterations_per_loop=FLAGS.iterations_per_loop,
|
||||
num_shards=FLAGS.num_tpu_cores,
|
||||
per_host_input_for_training=is_per_host))
|
||||
|
||||
model_fn = model_fn_builder(
|
||||
bert_config=bert_config,
|
||||
init_checkpoint=FLAGS.init_checkpoint,
|
||||
learning_rate=FLAGS.learning_rate,
|
||||
num_train_steps=FLAGS.num_train_steps,
|
||||
num_warmup_steps=FLAGS.num_warmup_steps,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
use_one_hot_embeddings=FLAGS.use_tpu)
|
||||
model_fn = model_fn_builder(
|
||||
bert_config=bert_config,
|
||||
init_checkpoint=FLAGS.init_checkpoint,
|
||||
learning_rate=FLAGS.learning_rate,
|
||||
num_train_steps=FLAGS.num_train_steps,
|
||||
num_warmup_steps=FLAGS.num_warmup_steps,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
use_one_hot_embeddings=FLAGS.use_tpu)
|
||||
|
||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||
# or GPU.
|
||||
estimator = tf.contrib.tpu.TPUEstimator(
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
model_fn=model_fn,
|
||||
config=run_config,
|
||||
train_batch_size=FLAGS.train_batch_size,
|
||||
eval_batch_size=FLAGS.eval_batch_size)
|
||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||
# or GPU.
|
||||
estimator = tf.contrib.tpu.TPUEstimator(
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
model_fn=model_fn,
|
||||
config=run_config,
|
||||
train_batch_size=FLAGS.train_batch_size,
|
||||
eval_batch_size=FLAGS.eval_batch_size)
|
||||
|
||||
if FLAGS.do_train:
|
||||
tf.logging.info("***** Running training *****")
|
||||
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
||||
train_input_fn = input_fn_builder(
|
||||
input_files=input_files,
|
||||
max_seq_length=FLAGS.max_seq_length,
|
||||
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||
is_training=True)
|
||||
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
|
||||
if FLAGS.do_train:
|
||||
tf.logging.info("***** Running training *****")
|
||||
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
||||
train_input_fn = input_fn_builder(
|
||||
input_files=input_files,
|
||||
max_seq_length=FLAGS.max_seq_length,
|
||||
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||
is_training=True)
|
||||
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
|
||||
|
||||
if FLAGS.do_eval:
|
||||
tf.logging.info("***** Running evaluation *****")
|
||||
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
||||
if FLAGS.do_eval:
|
||||
tf.logging.info("***** Running evaluation *****")
|
||||
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
||||
|
||||
eval_input_fn = input_fn_builder(
|
||||
input_files=input_files,
|
||||
max_seq_length=FLAGS.max_seq_length,
|
||||
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||
is_training=False)
|
||||
eval_input_fn = input_fn_builder(
|
||||
input_files=input_files,
|
||||
max_seq_length=FLAGS.max_seq_length,
|
||||
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||
is_training=False)
|
||||
|
||||
result = estimator.evaluate(
|
||||
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
|
||||
result = estimator.evaluate(
|
||||
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
|
||||
|
||||
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
||||
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
||||
tf.logging.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
tf.logging.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
||||
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
||||
tf.logging.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
tf.logging.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("bert_config_file")
|
||||
flags.mark_flag_as_required("output_dir")
|
||||
tf.app.run()
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("bert_config_file")
|
||||
flags.mark_flag_as_required("output_dir")
|
||||
tf.app.run()
|
||||
|
||||
1606
run_squad.py
1606
run_squad.py
File diff suppressed because it is too large
Load Diff
416
tokenization.py
416
tokenization.py
@@ -25,268 +25,268 @@ import tensorflow as tf
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with tf.gfile.GFile(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with tf.gfile.GFile(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(vocab[token])
|
||||
return ids
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(vocab[token])
|
||||
return ids
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_tokens_to_ids(self.vocab, tokens)
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_tokens_to_ids(self.vocab, tokens)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenziation."""
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -25,101 +25,101 @@ import tensorflow as tf
|
||||
|
||||
class TokenizationTest(tf.test.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
]
|
||||
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
]
|
||||
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
vocab_file = vocab_writer.name
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file)
|
||||
os.unlink(vocab_file)
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file)
|
||||
os.unlink(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
self.assertAllEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
def test_basic_tokenizer_lower(self):
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
||||
def test_basic_tokenizer_lower(self):
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["hello", "!", "how", "are", "you", "?"])
|
||||
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["hello", "!", "how", "are", "you", "?"])
|
||||
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||
|
||||
def test_basic_tokenizer_no_lower(self):
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
||||
def test_basic_tokenizer_no_lower(self):
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||
|
||||
def test_wordpiece_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing"
|
||||
]
|
||||
def test_wordpiece_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing"
|
||||
]
|
||||
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
||||
|
||||
self.assertAllEqual(tokenizer.tokenize(""), [])
|
||||
self.assertAllEqual(tokenizer.tokenize(""), [])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwanted running"),
|
||||
["un", "##want", "##ed", "runn", "##ing"])
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwanted running"),
|
||||
["un", "##want", "##ed", "runn", "##ing"])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||
|
||||
def test_convert_tokens_to_ids(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing"
|
||||
]
|
||||
def test_convert_tokens_to_ids(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing"
|
||||
]
|
||||
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenization.convert_tokens_to_ids(
|
||||
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
||||
self.assertAllEqual(
|
||||
tokenization.convert_tokens_to_ids(
|
||||
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
||||
|
||||
def test_is_whitespace(self):
|
||||
self.assertTrue(tokenization._is_whitespace(u" "))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
||||
def test_is_whitespace(self):
|
||||
self.assertTrue(tokenization._is_whitespace(u" "))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
||||
|
||||
self.assertFalse(tokenization._is_whitespace(u"A"))
|
||||
self.assertFalse(tokenization._is_whitespace(u"-"))
|
||||
self.assertFalse(tokenization._is_whitespace(u"A"))
|
||||
self.assertFalse(tokenization._is_whitespace(u"-"))
|
||||
|
||||
def test_is_control(self):
|
||||
self.assertTrue(tokenization._is_control(u"\u0005"))
|
||||
def test_is_control(self):
|
||||
self.assertTrue(tokenization._is_control(u"\u0005"))
|
||||
|
||||
self.assertFalse(tokenization._is_control(u"A"))
|
||||
self.assertFalse(tokenization._is_control(u" "))
|
||||
self.assertFalse(tokenization._is_control(u"\t"))
|
||||
self.assertFalse(tokenization._is_control(u"\r"))
|
||||
self.assertFalse(tokenization._is_control(u"A"))
|
||||
self.assertFalse(tokenization._is_control(u" "))
|
||||
self.assertFalse(tokenization._is_control(u"\t"))
|
||||
self.assertFalse(tokenization._is_control(u"\r"))
|
||||
|
||||
def test_is_punctuation(self):
|
||||
self.assertTrue(tokenization._is_punctuation(u"-"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"$"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"`"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"."))
|
||||
def test_is_punctuation(self):
|
||||
self.assertTrue(tokenization._is_punctuation(u"-"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"$"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"`"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"."))
|
||||
|
||||
self.assertFalse(tokenization._is_punctuation(u"A"))
|
||||
self.assertFalse(tokenization._is_punctuation(u" "))
|
||||
self.assertFalse(tokenization._is_punctuation(u"A"))
|
||||
self.assertFalse(tokenization._is_punctuation(u" "))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
tf.test.main()
|
||||
|
||||
Reference in New Issue
Block a user