Convert indentation from 2 spaces to 4 spaces
This commit is contained in:
@@ -63,379 +63,379 @@ flags.DEFINE_float(
|
|||||||
|
|
||||||
|
|
||||||
class TrainingInstance(object):
|
class TrainingInstance(object):
|
||||||
"""A single training instance (sentence pair)."""
|
"""A single training instance (sentence pair)."""
|
||||||
|
|
||||||
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||||
is_random_next):
|
is_random_next):
|
||||||
self.tokens = tokens
|
self.tokens = tokens
|
||||||
self.segment_ids = segment_ids
|
self.segment_ids = segment_ids
|
||||||
self.is_random_next = is_random_next
|
self.is_random_next = is_random_next
|
||||||
self.masked_lm_positions = masked_lm_positions
|
self.masked_lm_positions = masked_lm_positions
|
||||||
self.masked_lm_labels = masked_lm_labels
|
self.masked_lm_labels = masked_lm_labels
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
s = ""
|
s = ""
|
||||||
s += "tokens: %s\n" % (" ".join(
|
s += "tokens: %s\n" % (" ".join(
|
||||||
[tokenization.printable_text(x) for x in self.tokens]))
|
[tokenization.printable_text(x) for x in self.tokens]))
|
||||||
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||||
s += "is_random_next: %s\n" % self.is_random_next
|
s += "is_random_next: %s\n" % self.is_random_next
|
||||||
s += "masked_lm_positions: %s\n" % (" ".join(
|
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||||
[str(x) for x in self.masked_lm_positions]))
|
[str(x) for x in self.masked_lm_positions]))
|
||||||
s += "masked_lm_labels: %s\n" % (" ".join(
|
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||||
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||||
s += "\n"
|
s += "\n"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__str__()
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||||
max_predictions_per_seq, output_files):
|
max_predictions_per_seq, output_files):
|
||||||
"""Create TF example files from `TrainingInstance`s."""
|
"""Create TF example files from `TrainingInstance`s."""
|
||||||
writers = []
|
writers = []
|
||||||
for output_file in output_files:
|
for output_file in output_files:
|
||||||
writers.append(tf.python_io.TFRecordWriter(output_file))
|
writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||||
|
|
||||||
writer_index = 0
|
writer_index = 0
|
||||||
|
|
||||||
total_written = 0
|
total_written = 0
|
||||||
for (inst_index, instance) in enumerate(instances):
|
for (inst_index, instance) in enumerate(instances):
|
||||||
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||||
input_mask = [1] * len(input_ids)
|
input_mask = [1] * len(input_ids)
|
||||||
segment_ids = list(instance.segment_ids)
|
segment_ids = list(instance.segment_ids)
|
||||||
assert len(input_ids) <= max_seq_length
|
assert len(input_ids) <= max_seq_length
|
||||||
|
|
||||||
while len(input_ids) < max_seq_length:
|
while len(input_ids) < max_seq_length:
|
||||||
input_ids.append(0)
|
input_ids.append(0)
|
||||||
input_mask.append(0)
|
input_mask.append(0)
|
||||||
segment_ids.append(0)
|
segment_ids.append(0)
|
||||||
|
|
||||||
assert len(input_ids) == max_seq_length
|
assert len(input_ids) == max_seq_length
|
||||||
assert len(input_mask) == max_seq_length
|
assert len(input_mask) == max_seq_length
|
||||||
assert len(segment_ids) == max_seq_length
|
assert len(segment_ids) == max_seq_length
|
||||||
|
|
||||||
masked_lm_positions = list(instance.masked_lm_positions)
|
masked_lm_positions = list(instance.masked_lm_positions)
|
||||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||||
|
|
||||||
while len(masked_lm_positions) < max_predictions_per_seq:
|
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||||
masked_lm_positions.append(0)
|
masked_lm_positions.append(0)
|
||||||
masked_lm_ids.append(0)
|
masked_lm_ids.append(0)
|
||||||
masked_lm_weights.append(0.0)
|
masked_lm_weights.append(0.0)
|
||||||
|
|
||||||
next_sentence_label = 1 if instance.is_random_next else 0
|
next_sentence_label = 1 if instance.is_random_next else 0
|
||||||
|
|
||||||
features = collections.OrderedDict()
|
features = collections.OrderedDict()
|
||||||
features["input_ids"] = create_int_feature(input_ids)
|
features["input_ids"] = create_int_feature(input_ids)
|
||||||
features["input_mask"] = create_int_feature(input_mask)
|
features["input_mask"] = create_int_feature(input_mask)
|
||||||
features["segment_ids"] = create_int_feature(segment_ids)
|
features["segment_ids"] = create_int_feature(segment_ids)
|
||||||
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||||
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||||
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||||
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||||
|
|
||||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||||
|
|
||||||
writers[writer_index].write(tf_example.SerializeToString())
|
writers[writer_index].write(tf_example.SerializeToString())
|
||||||
writer_index = (writer_index + 1) % len(writers)
|
writer_index = (writer_index + 1) % len(writers)
|
||||||
|
|
||||||
total_written += 1
|
total_written += 1
|
||||||
|
|
||||||
if inst_index < 20:
|
if inst_index < 20:
|
||||||
tf.logging.info("*** Example ***")
|
tf.logging.info("*** Example ***")
|
||||||
tf.logging.info("tokens: %s" % " ".join(
|
tf.logging.info("tokens: %s" % " ".join(
|
||||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||||
|
|
||||||
for feature_name in features.keys():
|
for feature_name in features.keys():
|
||||||
feature = features[feature_name]
|
feature = features[feature_name]
|
||||||
values = []
|
values = []
|
||||||
if feature.int64_list.value:
|
if feature.int64_list.value:
|
||||||
values = feature.int64_list.value
|
values = feature.int64_list.value
|
||||||
elif feature.float_list.value:
|
elif feature.float_list.value:
|
||||||
values = feature.float_list.value
|
values = feature.float_list.value
|
||||||
tf.logging.info(
|
tf.logging.info(
|
||||||
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||||
|
|
||||||
for writer in writers:
|
for writer in writers:
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
tf.logging.info("Wrote %d total instances", total_written)
|
tf.logging.info("Wrote %d total instances", total_written)
|
||||||
|
|
||||||
|
|
||||||
def create_int_feature(values):
|
def create_int_feature(values):
|
||||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||||
return feature
|
return feature
|
||||||
|
|
||||||
|
|
||||||
def create_float_feature(values):
|
def create_float_feature(values):
|
||||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||||
return feature
|
return feature
|
||||||
|
|
||||||
|
|
||||||
def create_training_instances(input_files, tokenizer, max_seq_length,
|
def create_training_instances(input_files, tokenizer, max_seq_length,
|
||||||
dupe_factor, short_seq_prob, masked_lm_prob,
|
dupe_factor, short_seq_prob, masked_lm_prob,
|
||||||
max_predictions_per_seq, rng):
|
max_predictions_per_seq, rng):
|
||||||
"""Create `TrainingInstance`s from raw text."""
|
"""Create `TrainingInstance`s from raw text."""
|
||||||
all_documents = [[]]
|
all_documents = [[]]
|
||||||
|
|
||||||
# Input file format:
|
# Input file format:
|
||||||
# (1) One sentence per line. These should ideally be actual sentences, not
|
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||||
# entire paragraphs or arbitrary spans of text. (Because we use the
|
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||||
# sentence boundaries for the "next sentence prediction" task).
|
# sentence boundaries for the "next sentence prediction" task).
|
||||||
# (2) Blank lines between documents. Document boundaries are needed so
|
# (2) Blank lines between documents. Document boundaries are needed so
|
||||||
# that the "next sentence prediction" task doesn't span between documents.
|
# that the "next sentence prediction" task doesn't span between documents.
|
||||||
for input_file in input_files:
|
for input_file in input_files:
|
||||||
with tf.gfile.GFile(input_file, "r") as reader:
|
with tf.gfile.GFile(input_file, "r") as reader:
|
||||||
while True:
|
while True:
|
||||||
line = tokenization.convert_to_unicode(reader.readline())
|
line = tokenization.convert_to_unicode(reader.readline())
|
||||||
if not line:
|
if not line:
|
||||||
break
|
break
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
|
|
||||||
# Empty lines are used as document delimiters
|
# Empty lines are used as document delimiters
|
||||||
if not line:
|
if not line:
|
||||||
all_documents.append([])
|
all_documents.append([])
|
||||||
tokens = tokenizer.tokenize(line)
|
tokens = tokenizer.tokenize(line)
|
||||||
if tokens:
|
if tokens:
|
||||||
all_documents[-1].append(tokens)
|
all_documents[-1].append(tokens)
|
||||||
|
|
||||||
# Remove empty documents
|
# Remove empty documents
|
||||||
all_documents = [x for x in all_documents if x]
|
all_documents = [x for x in all_documents if x]
|
||||||
rng.shuffle(all_documents)
|
rng.shuffle(all_documents)
|
||||||
|
|
||||||
vocab_words = list(tokenizer.vocab.keys())
|
vocab_words = list(tokenizer.vocab.keys())
|
||||||
instances = []
|
instances = []
|
||||||
for _ in range(dupe_factor):
|
for _ in range(dupe_factor):
|
||||||
for document_index in range(len(all_documents)):
|
for document_index in range(len(all_documents)):
|
||||||
instances.extend(
|
instances.extend(
|
||||||
create_instances_from_document(
|
create_instances_from_document(
|
||||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||||
|
|
||||||
rng.shuffle(instances)
|
rng.shuffle(instances)
|
||||||
return instances
|
return instances
|
||||||
|
|
||||||
|
|
||||||
def create_instances_from_document(
|
def create_instances_from_document(
|
||||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||||
"""Creates `TrainingInstance`s for a single document."""
|
"""Creates `TrainingInstance`s for a single document."""
|
||||||
document = all_documents[document_index]
|
document = all_documents[document_index]
|
||||||
|
|
||||||
# Account for [CLS], [SEP], [SEP]
|
# Account for [CLS], [SEP], [SEP]
|
||||||
max_num_tokens = max_seq_length - 3
|
max_num_tokens = max_seq_length - 3
|
||||||
|
|
||||||
# We *usually* want to fill up the entire sequence since we are padding
|
# We *usually* want to fill up the entire sequence since we are padding
|
||||||
# to `max_seq_length` anyways, so short sequences are generally wasted
|
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||||
# computation. However, we *sometimes*
|
# computation. However, we *sometimes*
|
||||||
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||||
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||||
# The `target_seq_length` is just a rough target however, whereas
|
# The `target_seq_length` is just a rough target however, whereas
|
||||||
# `max_seq_length` is a hard limit.
|
# `max_seq_length` is a hard limit.
|
||||||
target_seq_length = max_num_tokens
|
target_seq_length = max_num_tokens
|
||||||
if rng.random() < short_seq_prob:
|
if rng.random() < short_seq_prob:
|
||||||
target_seq_length = rng.randint(2, max_num_tokens)
|
target_seq_length = rng.randint(2, max_num_tokens)
|
||||||
|
|
||||||
# We DON'T just concatenate all of the tokens from a document into a long
|
# We DON'T just concatenate all of the tokens from a document into a long
|
||||||
# sequence and choose an arbitrary split point because this would make the
|
# sequence and choose an arbitrary split point because this would make the
|
||||||
# next sentence prediction task too easy. Instead, we split the input into
|
# next sentence prediction task too easy. Instead, we split the input into
|
||||||
# segments "A" and "B" based on the actual "sentences" provided by the user
|
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||||
# input.
|
# input.
|
||||||
instances = []
|
instances = []
|
||||||
current_chunk = []
|
current_chunk = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(document):
|
while i < len(document):
|
||||||
segment = document[i]
|
segment = document[i]
|
||||||
current_chunk.append(segment)
|
current_chunk.append(segment)
|
||||||
current_length += len(segment)
|
current_length += len(segment)
|
||||||
if i == len(document) - 1 or current_length >= target_seq_length:
|
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
# `a_end` is how many segments from `current_chunk` go into the `A`
|
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||||
# (first) sentence.
|
# (first) sentence.
|
||||||
a_end = 1
|
a_end = 1
|
||||||
if len(current_chunk) >= 2:
|
if len(current_chunk) >= 2:
|
||||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||||
|
|
||||||
tokens_a = []
|
tokens_a = []
|
||||||
for j in range(a_end):
|
for j in range(a_end):
|
||||||
tokens_a.extend(current_chunk[j])
|
tokens_a.extend(current_chunk[j])
|
||||||
|
|
||||||
tokens_b = []
|
tokens_b = []
|
||||||
# Random next
|
# Random next
|
||||||
is_random_next = False
|
is_random_next = False
|
||||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||||
is_random_next = True
|
is_random_next = True
|
||||||
target_b_length = target_seq_length - len(tokens_a)
|
target_b_length = target_seq_length - len(tokens_a)
|
||||||
|
|
||||||
# This should rarely go for more than one iteration for large
|
# This should rarely go for more than one iteration for large
|
||||||
# corpora. However, just to be careful, we try to make sure that
|
# corpora. However, just to be careful, we try to make sure that
|
||||||
# the random document is not the same as the document
|
# the random document is not the same as the document
|
||||||
# we're processing.
|
# we're processing.
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
random_document_index = rng.randint(0, len(all_documents) - 1)
|
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||||
if random_document_index != document_index:
|
if random_document_index != document_index:
|
||||||
break
|
break
|
||||||
|
|
||||||
random_document = all_documents[random_document_index]
|
random_document = all_documents[random_document_index]
|
||||||
random_start = rng.randint(0, len(random_document) - 1)
|
random_start = rng.randint(0, len(random_document) - 1)
|
||||||
for j in range(random_start, len(random_document)):
|
for j in range(random_start, len(random_document)):
|
||||||
tokens_b.extend(random_document[j])
|
tokens_b.extend(random_document[j])
|
||||||
if len(tokens_b) >= target_b_length:
|
if len(tokens_b) >= target_b_length:
|
||||||
break
|
break
|
||||||
# We didn't actually use these segments so we "put them back" so
|
# We didn't actually use these segments so we "put them back" so
|
||||||
# they don't go to waste.
|
# they don't go to waste.
|
||||||
num_unused_segments = len(current_chunk) - a_end
|
num_unused_segments = len(current_chunk) - a_end
|
||||||
i -= num_unused_segments
|
i -= num_unused_segments
|
||||||
# Actual next
|
# Actual next
|
||||||
else:
|
else:
|
||||||
is_random_next = False
|
is_random_next = False
|
||||||
for j in range(a_end, len(current_chunk)):
|
for j in range(a_end, len(current_chunk)):
|
||||||
tokens_b.extend(current_chunk[j])
|
tokens_b.extend(current_chunk[j])
|
||||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||||
|
|
||||||
assert len(tokens_a) >= 1
|
assert len(tokens_a) >= 1
|
||||||
assert len(tokens_b) >= 1
|
assert len(tokens_b) >= 1
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
segment_ids = []
|
segment_ids = []
|
||||||
tokens.append("[CLS]")
|
tokens.append("[CLS]")
|
||||||
segment_ids.append(0)
|
segment_ids.append(0)
|
||||||
for token in tokens_a:
|
for token in tokens_a:
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
segment_ids.append(0)
|
segment_ids.append(0)
|
||||||
|
|
||||||
tokens.append("[SEP]")
|
tokens.append("[SEP]")
|
||||||
segment_ids.append(0)
|
segment_ids.append(0)
|
||||||
|
|
||||||
for token in tokens_b:
|
for token in tokens_b:
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
segment_ids.append(1)
|
segment_ids.append(1)
|
||||||
tokens.append("[SEP]")
|
tokens.append("[SEP]")
|
||||||
segment_ids.append(1)
|
segment_ids.append(1)
|
||||||
|
|
||||||
(tokens, masked_lm_positions,
|
(tokens, masked_lm_positions,
|
||||||
masked_lm_labels) = create_masked_lm_predictions(
|
masked_lm_labels) = create_masked_lm_predictions(
|
||||||
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||||
instance = TrainingInstance(
|
instance = TrainingInstance(
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
segment_ids=segment_ids,
|
segment_ids=segment_ids,
|
||||||
is_random_next=is_random_next,
|
is_random_next=is_random_next,
|
||||||
masked_lm_positions=masked_lm_positions,
|
masked_lm_positions=masked_lm_positions,
|
||||||
masked_lm_labels=masked_lm_labels)
|
masked_lm_labels=masked_lm_labels)
|
||||||
instances.append(instance)
|
instances.append(instance)
|
||||||
current_chunk = []
|
current_chunk = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return instances
|
return instances
|
||||||
|
|
||||||
|
|
||||||
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||||
max_predictions_per_seq, vocab_words, rng):
|
max_predictions_per_seq, vocab_words, rng):
|
||||||
"""Creates the predictis for the masked LM objective."""
|
"""Creates the predictis for the masked LM objective."""
|
||||||
|
|
||||||
cand_indexes = []
|
cand_indexes = []
|
||||||
for (i, token) in enumerate(tokens):
|
for (i, token) in enumerate(tokens):
|
||||||
if token == "[CLS]" or token == "[SEP]":
|
if token == "[CLS]" or token == "[SEP]":
|
||||||
continue
|
continue
|
||||||
cand_indexes.append(i)
|
cand_indexes.append(i)
|
||||||
|
|
||||||
rng.shuffle(cand_indexes)
|
rng.shuffle(cand_indexes)
|
||||||
|
|
||||||
output_tokens = list(tokens)
|
output_tokens = list(tokens)
|
||||||
|
|
||||||
masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
|
masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
|
||||||
|
|
||||||
num_to_predict = min(max_predictions_per_seq,
|
num_to_predict = min(max_predictions_per_seq,
|
||||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||||
|
|
||||||
masked_lms = []
|
masked_lms = []
|
||||||
covered_indexes = set()
|
covered_indexes = set()
|
||||||
for index in cand_indexes:
|
for index in cand_indexes:
|
||||||
if len(masked_lms) >= num_to_predict:
|
if len(masked_lms) >= num_to_predict:
|
||||||
break
|
break
|
||||||
if index in covered_indexes:
|
if index in covered_indexes:
|
||||||
continue
|
continue
|
||||||
covered_indexes.add(index)
|
covered_indexes.add(index)
|
||||||
|
|
||||||
masked_token = None
|
masked_token = None
|
||||||
# 80% of the time, replace with [MASK]
|
# 80% of the time, replace with [MASK]
|
||||||
if rng.random() < 0.8:
|
if rng.random() < 0.8:
|
||||||
masked_token = "[MASK]"
|
masked_token = "[MASK]"
|
||||||
else:
|
else:
|
||||||
# 10% of the time, keep original
|
# 10% of the time, keep original
|
||||||
if rng.random() < 0.5:
|
if rng.random() < 0.5:
|
||||||
masked_token = tokens[index]
|
masked_token = tokens[index]
|
||||||
# 10% of the time, replace with random word
|
# 10% of the time, replace with random word
|
||||||
else:
|
else:
|
||||||
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||||
|
|
||||||
output_tokens[index] = masked_token
|
output_tokens[index] = masked_token
|
||||||
|
|
||||||
masked_lms.append(masked_lm(index=index, label=tokens[index]))
|
masked_lms.append(masked_lm(index=index, label=tokens[index]))
|
||||||
|
|
||||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||||
|
|
||||||
masked_lm_positions = []
|
masked_lm_positions = []
|
||||||
masked_lm_labels = []
|
masked_lm_labels = []
|
||||||
for p in masked_lms:
|
for p in masked_lms:
|
||||||
masked_lm_positions.append(p.index)
|
masked_lm_positions.append(p.index)
|
||||||
masked_lm_labels.append(p.label)
|
masked_lm_labels.append(p.label)
|
||||||
|
|
||||||
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||||
|
|
||||||
|
|
||||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
||||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||||
while True:
|
while True:
|
||||||
total_length = len(tokens_a) + len(tokens_b)
|
total_length = len(tokens_a) + len(tokens_b)
|
||||||
if total_length <= max_num_tokens:
|
if total_length <= max_num_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||||
assert len(trunc_tokens) >= 1
|
assert len(trunc_tokens) >= 1
|
||||||
|
|
||||||
# We want to sometimes truncate from the front and sometimes from the
|
# We want to sometimes truncate from the front and sometimes from the
|
||||||
# back to add more randomness and avoid biases.
|
# back to add more randomness and avoid biases.
|
||||||
if rng.random() < 0.5:
|
if rng.random() < 0.5:
|
||||||
del trunc_tokens[0]
|
del trunc_tokens[0]
|
||||||
else:
|
else:
|
||||||
trunc_tokens.pop()
|
trunc_tokens.pop()
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
tokenizer = tokenization.FullTokenizer(
|
tokenizer = tokenization.FullTokenizer(
|
||||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||||
|
|
||||||
input_files = []
|
input_files = []
|
||||||
for input_pattern in FLAGS.input_file.split(","):
|
for input_pattern in FLAGS.input_file.split(","):
|
||||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||||
|
|
||||||
tf.logging.info("*** Reading from input files ***")
|
tf.logging.info("*** Reading from input files ***")
|
||||||
for input_file in input_files:
|
for input_file in input_files:
|
||||||
tf.logging.info(" %s", input_file)
|
tf.logging.info(" %s", input_file)
|
||||||
|
|
||||||
rng = random.Random(FLAGS.random_seed)
|
rng = random.Random(FLAGS.random_seed)
|
||||||
instances = create_training_instances(
|
instances = create_training_instances(
|
||||||
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||||
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||||
rng)
|
rng)
|
||||||
|
|
||||||
output_files = FLAGS.output_file.split(",")
|
output_files = FLAGS.output_file.split(",")
|
||||||
tf.logging.info("*** Writing to output files ***")
|
tf.logging.info("*** Writing to output files ***")
|
||||||
for output_file in output_files:
|
for output_file in output_files:
|
||||||
tf.logging.info(" %s", output_file)
|
tf.logging.info(" %s", output_file)
|
||||||
|
|
||||||
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||||
FLAGS.max_predictions_per_seq, output_files)
|
FLAGS.max_predictions_per_seq, output_files)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
flags.mark_flag_as_required("input_file")
|
flags.mark_flag_as_required("input_file")
|
||||||
flags.mark_flag_as_required("output_file")
|
flags.mark_flag_as_required("output_file")
|
||||||
flags.mark_flag_as_required("vocab_file")
|
flags.mark_flag_as_required("vocab_file")
|
||||||
tf.app.run()
|
tf.app.run()
|
||||||
|
|||||||
@@ -80,330 +80,330 @@ flags.DEFINE_bool(
|
|||||||
|
|
||||||
class InputExample(object):
|
class InputExample(object):
|
||||||
|
|
||||||
def __init__(self, unique_id, text_a, text_b):
|
def __init__(self, unique_id, text_a, text_b):
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.text_a = text_a
|
self.text_a = text_a
|
||||||
self.text_b = text_b
|
self.text_b = text_b
|
||||||
|
|
||||||
|
|
||||||
class InputFeatures(object):
|
class InputFeatures(object):
|
||||||
"""A single set of features of data."""
|
"""A single set of features of data."""
|
||||||
|
|
||||||
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
|
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.tokens = tokens
|
self.tokens = tokens
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.input_mask = input_mask
|
self.input_mask = input_mask
|
||||||
self.input_type_ids = input_type_ids
|
self.input_type_ids = input_type_ids
|
||||||
|
|
||||||
|
|
||||||
def input_fn_builder(features, seq_length):
|
def input_fn_builder(features, seq_length):
|
||||||
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||||
|
|
||||||
all_unique_ids = []
|
all_unique_ids = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
all_input_mask = []
|
all_input_mask = []
|
||||||
all_input_type_ids = []
|
all_input_type_ids = []
|
||||||
|
|
||||||
for feature in features:
|
for feature in features:
|
||||||
all_unique_ids.append(feature.unique_id)
|
all_unique_ids.append(feature.unique_id)
|
||||||
all_input_ids.append(feature.input_ids)
|
all_input_ids.append(feature.input_ids)
|
||||||
all_input_mask.append(feature.input_mask)
|
all_input_mask.append(feature.input_mask)
|
||||||
all_input_type_ids.append(feature.input_type_ids)
|
all_input_type_ids.append(feature.input_type_ids)
|
||||||
|
|
||||||
def input_fn(params):
|
def input_fn(params):
|
||||||
"""The actual input function."""
|
"""The actual input function."""
|
||||||
batch_size = params["batch_size"]
|
batch_size = params["batch_size"]
|
||||||
|
|
||||||
num_examples = len(features)
|
num_examples = len(features)
|
||||||
|
|
||||||
# This is for demo purposes and does NOT scale to large data sets. We do
|
# This is for demo purposes and does NOT scale to large data sets. We do
|
||||||
# not use Dataset.from_generator() because that uses tf.py_func which is
|
# not use Dataset.from_generator() because that uses tf.py_func which is
|
||||||
# not TPU compatible. The right way to load data is with TFRecordReader.
|
# not TPU compatible. The right way to load data is with TFRecordReader.
|
||||||
d = tf.data.Dataset.from_tensor_slices({
|
d = tf.data.Dataset.from_tensor_slices({
|
||||||
"unique_ids":
|
"unique_ids":
|
||||||
tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
|
tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
|
||||||
"input_ids":
|
"input_ids":
|
||||||
tf.constant(
|
tf.constant(
|
||||||
all_input_ids, shape=[num_examples, seq_length],
|
all_input_ids, shape=[num_examples, seq_length],
|
||||||
dtype=tf.int32),
|
dtype=tf.int32),
|
||||||
"input_mask":
|
"input_mask":
|
||||||
tf.constant(
|
tf.constant(
|
||||||
all_input_mask,
|
all_input_mask,
|
||||||
shape=[num_examples, seq_length],
|
shape=[num_examples, seq_length],
|
||||||
dtype=tf.int32),
|
dtype=tf.int32),
|
||||||
"input_type_ids":
|
"input_type_ids":
|
||||||
tf.constant(
|
tf.constant(
|
||||||
all_input_type_ids,
|
all_input_type_ids,
|
||||||
shape=[num_examples, seq_length],
|
shape=[num_examples, seq_length],
|
||||||
dtype=tf.int32),
|
dtype=tf.int32),
|
||||||
})
|
})
|
||||||
|
|
||||||
d = d.batch(batch_size=batch_size, drop_remainder=False)
|
d = d.batch(batch_size=batch_size, drop_remainder=False)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return input_fn
|
return input_fn
|
||||||
|
|
||||||
|
|
||||||
def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
|
def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
|
||||||
use_one_hot_embeddings):
|
use_one_hot_embeddings):
|
||||||
"""Returns `model_fn` closure for TPUEstimator."""
|
"""Returns `model_fn` closure for TPUEstimator."""
|
||||||
|
|
||||||
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
"""The `model_fn` for TPUEstimator."""
|
"""The `model_fn` for TPUEstimator."""
|
||||||
|
|
||||||
unique_ids = features["unique_ids"]
|
unique_ids = features["unique_ids"]
|
||||||
input_ids = features["input_ids"]
|
input_ids = features["input_ids"]
|
||||||
input_mask = features["input_mask"]
|
input_mask = features["input_mask"]
|
||||||
input_type_ids = features["input_type_ids"]
|
input_type_ids = features["input_type_ids"]
|
||||||
|
|
||||||
model = modeling.BertModel(
|
model = modeling.BertModel(
|
||||||
config=bert_config,
|
config=bert_config,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
token_type_ids=input_type_ids,
|
token_type_ids=input_type_ids,
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||||
|
|
||||||
if mode != tf.estimator.ModeKeys.PREDICT:
|
if mode != tf.estimator.ModeKeys.PREDICT:
|
||||||
raise ValueError("Only PREDICT modes are supported: %s" % (mode))
|
raise ValueError("Only PREDICT modes are supported: %s" % (mode))
|
||||||
|
|
||||||
tvars = tf.trainable_variables()
|
tvars = tf.trainable_variables()
|
||||||
scaffold_fn = None
|
scaffold_fn = None
|
||||||
(assignment_map, _) = modeling.get_assigment_map_from_checkpoint(
|
(assignment_map, _) = modeling.get_assigment_map_from_checkpoint(
|
||||||
tvars, init_checkpoint)
|
tvars, init_checkpoint)
|
||||||
if use_tpu:
|
if use_tpu:
|
||||||
|
|
||||||
def tpu_scaffold():
|
def tpu_scaffold():
|
||||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
return tf.train.Scaffold()
|
return tf.train.Scaffold()
|
||||||
|
|
||||||
scaffold_fn = tpu_scaffold
|
scaffold_fn = tpu_scaffold
|
||||||
else:
|
else:
|
||||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
|
|
||||||
all_layers = model.get_all_encoder_layers()
|
all_layers = model.get_all_encoder_layers()
|
||||||
|
|
||||||
predictions = {
|
predictions = {
|
||||||
"unique_id": unique_ids,
|
"unique_id": unique_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
for (i, layer_index) in enumerate(layer_indexes):
|
for (i, layer_index) in enumerate(layer_indexes):
|
||||||
predictions["layer_output_%d" % i] = all_layers[layer_index]
|
predictions["layer_output_%d" % i] = all_layers[layer_index]
|
||||||
|
|
||||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
|
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
|
||||||
return output_spec
|
return output_spec
|
||||||
|
|
||||||
return model_fn
|
return model_fn
|
||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, seq_length, tokenizer):
|
def convert_examples_to_features(examples, seq_length, tokenizer):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for (ex_index, example) in enumerate(examples):
|
for (ex_index, example) in enumerate(examples):
|
||||||
tokens_a = tokenizer.tokenize(example.text_a)
|
tokens_a = tokenizer.tokenize(example.text_a)
|
||||||
|
|
||||||
tokens_b = None
|
tokens_b = None
|
||||||
if example.text_b:
|
if example.text_b:
|
||||||
tokens_b = tokenizer.tokenize(example.text_b)
|
tokens_b = tokenizer.tokenize(example.text_b)
|
||||||
|
|
||||||
if tokens_b:
|
if tokens_b:
|
||||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||||
# length is less than the specified length.
|
# length is less than the specified length.
|
||||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||||
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
|
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
|
||||||
else:
|
else:
|
||||||
# Account for [CLS] and [SEP] with "- 2"
|
# Account for [CLS] and [SEP] with "- 2"
|
||||||
if len(tokens_a) > seq_length - 2:
|
if len(tokens_a) > seq_length - 2:
|
||||||
tokens_a = tokens_a[0:(seq_length - 2)]
|
tokens_a = tokens_a[0:(seq_length - 2)]
|
||||||
|
|
||||||
# The convention in BERT is:
|
# The convention in BERT is:
|
||||||
# (a) For sequence pairs:
|
# (a) For sequence pairs:
|
||||||
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||||
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||||
# (b) For single sequences:
|
# (b) For single sequences:
|
||||||
# tokens: [CLS] the dog is hairy . [SEP]
|
# tokens: [CLS] the dog is hairy . [SEP]
|
||||||
# type_ids: 0 0 0 0 0 0 0
|
# type_ids: 0 0 0 0 0 0 0
|
||||||
#
|
#
|
||||||
# Where "type_ids" are used to indicate whether this is the first
|
# Where "type_ids" are used to indicate whether this is the first
|
||||||
# sequence or the second sequence. The embedding vectors for `type=0` and
|
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||||
# `type=1` were learned during pre-training and are added to the wordpiece
|
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||||
# embedding vector (and position vector). This is not *strictly* necessary
|
# embedding vector (and position vector). This is not *strictly* necessary
|
||||||
# since the [SEP] token unambigiously separates the sequences, but it makes
|
# since the [SEP] token unambigiously separates the sequences, but it makes
|
||||||
# it easier for the model to learn the concept of sequences.
|
# it easier for the model to learn the concept of sequences.
|
||||||
#
|
#
|
||||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||||
# used as as the "sentence vector". Note that this only makes sense because
|
# used as as the "sentence vector". Note that this only makes sense because
|
||||||
# the entire model is fine-tuned.
|
# the entire model is fine-tuned.
|
||||||
tokens = []
|
tokens = []
|
||||||
input_type_ids = []
|
input_type_ids = []
|
||||||
tokens.append("[CLS]")
|
tokens.append("[CLS]")
|
||||||
input_type_ids.append(0)
|
input_type_ids.append(0)
|
||||||
for token in tokens_a:
|
for token in tokens_a:
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
input_type_ids.append(0)
|
input_type_ids.append(0)
|
||||||
tokens.append("[SEP]")
|
tokens.append("[SEP]")
|
||||||
input_type_ids.append(0)
|
input_type_ids.append(0)
|
||||||
|
|
||||||
if tokens_b:
|
if tokens_b:
|
||||||
for token in tokens_b:
|
for token in tokens_b:
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
input_type_ids.append(1)
|
input_type_ids.append(1)
|
||||||
tokens.append("[SEP]")
|
tokens.append("[SEP]")
|
||||||
input_type_ids.append(1)
|
input_type_ids.append(1)
|
||||||
|
|
||||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
|
||||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||||
# tokens are attended to.
|
# tokens are attended to.
|
||||||
input_mask = [1] * len(input_ids)
|
input_mask = [1] * len(input_ids)
|
||||||
|
|
||||||
# Zero-pad up to the sequence length.
|
# Zero-pad up to the sequence length.
|
||||||
while len(input_ids) < seq_length:
|
while len(input_ids) < seq_length:
|
||||||
input_ids.append(0)
|
input_ids.append(0)
|
||||||
input_mask.append(0)
|
input_mask.append(0)
|
||||||
input_type_ids.append(0)
|
input_type_ids.append(0)
|
||||||
|
|
||||||
assert len(input_ids) == seq_length
|
assert len(input_ids) == seq_length
|
||||||
assert len(input_mask) == seq_length
|
assert len(input_mask) == seq_length
|
||||||
assert len(input_type_ids) == seq_length
|
assert len(input_type_ids) == seq_length
|
||||||
|
|
||||||
if ex_index < 5:
|
if ex_index < 5:
|
||||||
tf.logging.info("*** Example ***")
|
tf.logging.info("*** Example ***")
|
||||||
tf.logging.info("unique_id: %s" % (example.unique_id))
|
tf.logging.info("unique_id: %s" % (example.unique_id))
|
||||||
tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
|
tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
|
||||||
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||||
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||||
tf.logging.info(
|
tf.logging.info(
|
||||||
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
|
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
|
||||||
|
|
||||||
features.append(
|
features.append(
|
||||||
InputFeatures(
|
InputFeatures(
|
||||||
unique_id=example.unique_id,
|
unique_id=example.unique_id,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
input_type_ids=input_type_ids))
|
input_type_ids=input_type_ids))
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
"""Truncates a sequence pair in place to the maximum length."""
|
"""Truncates a sequence pair in place to the maximum length."""
|
||||||
|
|
||||||
# This is a simple heuristic which will always truncate the longer sequence
|
# This is a simple heuristic which will always truncate the longer sequence
|
||||||
# one token at a time. This makes more sense than truncating an equal percent
|
# one token at a time. This makes more sense than truncating an equal percent
|
||||||
# of tokens from each, since if one sequence is very short then each token
|
# of tokens from each, since if one sequence is very short then each token
|
||||||
# that's truncated likely contains more information than a longer sequence.
|
# that's truncated likely contains more information than a longer sequence.
|
||||||
while True:
|
while True:
|
||||||
total_length = len(tokens_a) + len(tokens_b)
|
total_length = len(tokens_a) + len(tokens_b)
|
||||||
if total_length <= max_length:
|
if total_length <= max_length:
|
||||||
break
|
break
|
||||||
if len(tokens_a) > len(tokens_b):
|
if len(tokens_a) > len(tokens_b):
|
||||||
tokens_a.pop()
|
tokens_a.pop()
|
||||||
else:
|
else:
|
||||||
tokens_b.pop()
|
tokens_b.pop()
|
||||||
|
|
||||||
|
|
||||||
def read_examples(input_file):
|
def read_examples(input_file):
|
||||||
"""Read a list of `InputExample`s from an input file."""
|
"""Read a list of `InputExample`s from an input file."""
|
||||||
examples = []
|
examples = []
|
||||||
unique_id = 0
|
unique_id = 0
|
||||||
with tf.gfile.GFile(input_file, "r") as reader:
|
with tf.gfile.GFile(input_file, "r") as reader:
|
||||||
while True:
|
while True:
|
||||||
line = tokenization.convert_to_unicode(reader.readline())
|
line = tokenization.convert_to_unicode(reader.readline())
|
||||||
if not line:
|
if not line:
|
||||||
break
|
break
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
text_a = None
|
text_a = None
|
||||||
text_b = None
|
text_b = None
|
||||||
m = re.match(r"^(.*) \|\|\| (.*)$", line)
|
m = re.match(r"^(.*) \|\|\| (.*)$", line)
|
||||||
if m is None:
|
if m is None:
|
||||||
text_a = line
|
text_a = line
|
||||||
else:
|
else:
|
||||||
text_a = m.group(1)
|
text_a = m.group(1)
|
||||||
text_b = m.group(2)
|
text_b = m.group(2)
|
||||||
examples.append(
|
examples.append(
|
||||||
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
|
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
|
||||||
unique_id += 1
|
unique_id += 1
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
|
layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
|
||||||
|
|
||||||
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||||
|
|
||||||
tokenizer = tokenization.FullTokenizer(
|
tokenizer = tokenization.FullTokenizer(
|
||||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||||
|
|
||||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||||
run_config = tf.contrib.tpu.RunConfig(
|
run_config = tf.contrib.tpu.RunConfig(
|
||||||
master=FLAGS.master,
|
master=FLAGS.master,
|
||||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||||
num_shards=FLAGS.num_tpu_cores,
|
num_shards=FLAGS.num_tpu_cores,
|
||||||
per_host_input_for_training=is_per_host))
|
per_host_input_for_training=is_per_host))
|
||||||
|
|
||||||
examples = read_examples(FLAGS.input_file)
|
examples = read_examples(FLAGS.input_file)
|
||||||
|
|
||||||
features = convert_examples_to_features(
|
features = convert_examples_to_features(
|
||||||
examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
|
examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
|
||||||
|
|
||||||
unique_id_to_feature = {}
|
unique_id_to_feature = {}
|
||||||
for feature in features:
|
for feature in features:
|
||||||
unique_id_to_feature[feature.unique_id] = feature
|
unique_id_to_feature[feature.unique_id] = feature
|
||||||
|
|
||||||
model_fn = model_fn_builder(
|
model_fn = model_fn_builder(
|
||||||
bert_config=bert_config,
|
bert_config=bert_config,
|
||||||
init_checkpoint=FLAGS.init_checkpoint,
|
init_checkpoint=FLAGS.init_checkpoint,
|
||||||
layer_indexes=layer_indexes,
|
layer_indexes=layer_indexes,
|
||||||
use_tpu=FLAGS.use_tpu,
|
use_tpu=FLAGS.use_tpu,
|
||||||
use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
|
use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
|
||||||
|
|
||||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
# or GPU.
|
# or GPU.
|
||||||
estimator = tf.contrib.tpu.TPUEstimator(
|
estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
use_tpu=FLAGS.use_tpu,
|
use_tpu=FLAGS.use_tpu,
|
||||||
model_fn=model_fn,
|
model_fn=model_fn,
|
||||||
config=run_config,
|
config=run_config,
|
||||||
predict_batch_size=FLAGS.batch_size)
|
predict_batch_size=FLAGS.batch_size)
|
||||||
|
|
||||||
input_fn = input_fn_builder(
|
input_fn = input_fn_builder(
|
||||||
features=features, seq_length=FLAGS.max_seq_length)
|
features=features, seq_length=FLAGS.max_seq_length)
|
||||||
|
|
||||||
with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
|
with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
|
||||||
"w")) as writer:
|
"w")) as writer:
|
||||||
for result in estimator.predict(input_fn, yield_single_examples=True):
|
for result in estimator.predict(input_fn, yield_single_examples=True):
|
||||||
unique_id = int(result["unique_id"])
|
unique_id = int(result["unique_id"])
|
||||||
feature = unique_id_to_feature[unique_id]
|
feature = unique_id_to_feature[unique_id]
|
||||||
output_json = collections.OrderedDict()
|
output_json = collections.OrderedDict()
|
||||||
output_json["linex_index"] = unique_id
|
output_json["linex_index"] = unique_id
|
||||||
all_features = []
|
all_features = []
|
||||||
for (i, token) in enumerate(feature.tokens):
|
for (i, token) in enumerate(feature.tokens):
|
||||||
all_layers = []
|
all_layers = []
|
||||||
for (j, layer_index) in enumerate(layer_indexes):
|
for (j, layer_index) in enumerate(layer_indexes):
|
||||||
layer_output = result["layer_output_%d" % j]
|
layer_output = result["layer_output_%d" % j]
|
||||||
layers = collections.OrderedDict()
|
layers = collections.OrderedDict()
|
||||||
layers["index"] = layer_index
|
layers["index"] = layer_index
|
||||||
layers["values"] = [
|
layers["values"] = [
|
||||||
round(float(x), 6) for x in layer_output[i:(i + 1)].flat
|
round(float(x), 6) for x in layer_output[i:(i + 1)].flat
|
||||||
]
|
]
|
||||||
all_layers.append(layers)
|
all_layers.append(layers)
|
||||||
features = collections.OrderedDict()
|
features = collections.OrderedDict()
|
||||||
features["token"] = token
|
features["token"] = token
|
||||||
features["layers"] = all_layers
|
features["layers"] = all_layers
|
||||||
all_features.append(features)
|
all_features.append(features)
|
||||||
output_json["features"] = all_features
|
output_json["features"] = all_features
|
||||||
writer.write(json.dumps(output_json) + "\n")
|
writer.write(json.dumps(output_json) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
flags.mark_flag_as_required("input_file")
|
flags.mark_flag_as_required("input_file")
|
||||||
flags.mark_flag_as_required("vocab_file")
|
flags.mark_flag_as_required("vocab_file")
|
||||||
flags.mark_flag_as_required("bert_config_file")
|
flags.mark_flag_as_required("bert_config_file")
|
||||||
flags.mark_flag_as_required("init_checkpoint")
|
flags.mark_flag_as_required("init_checkpoint")
|
||||||
flags.mark_flag_as_required("output_file")
|
flags.mark_flag_as_required("output_file")
|
||||||
tf.app.run()
|
tf.app.run()
|
||||||
|
|||||||
1496
modeling.py
1496
modeling.py
File diff suppressed because it is too large
Load Diff
421
modeling_test.py
421
modeling_test.py
@@ -27,250 +27,249 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
class BertModelTest(tf.test.TestCase):
|
class BertModelTest(tf.test.TestCase):
|
||||||
|
class BertModelTester(object):
|
||||||
|
|
||||||
class BertModelTester(object):
|
def __init__(self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
initializer_range=0.02,
|
||||||
|
scope=None):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
def __init__(self,
|
def create_model(self):
|
||||||
parent,
|
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
|
||||||
batch_size=13,
|
self.vocab_size)
|
||||||
seq_length=7,
|
|
||||||
is_training=True,
|
|
||||||
use_input_mask=True,
|
|
||||||
use_token_type_ids=True,
|
|
||||||
vocab_size=99,
|
|
||||||
hidden_size=32,
|
|
||||||
num_hidden_layers=5,
|
|
||||||
num_attention_heads=4,
|
|
||||||
intermediate_size=37,
|
|
||||||
hidden_act="gelu",
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
max_position_embeddings=512,
|
|
||||||
type_vocab_size=16,
|
|
||||||
initializer_range=0.02,
|
|
||||||
scope=None):
|
|
||||||
self.parent = parent
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.is_training = is_training
|
|
||||||
self.use_input_mask = use_input_mask
|
|
||||||
self.use_token_type_ids = use_token_type_ids
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.hidden_dropout_prob = hidden_dropout_prob
|
|
||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.type_vocab_size = type_vocab_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.scope = scope
|
|
||||||
|
|
||||||
def create_model(self):
|
input_mask = None
|
||||||
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
|
if self.use_input_mask:
|
||||||
self.vocab_size)
|
input_mask = BertModelTest.ids_tensor(
|
||||||
|
[self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
input_mask = None
|
token_type_ids = None
|
||||||
if self.use_input_mask:
|
if self.use_token_type_ids:
|
||||||
input_mask = BertModelTest.ids_tensor(
|
token_type_ids = BertModelTest.ids_tensor(
|
||||||
[self.batch_size, self.seq_length], vocab_size=2)
|
[self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
token_type_ids = None
|
config = modeling.BertConfig(
|
||||||
if self.use_token_type_ids:
|
vocab_size=self.vocab_size,
|
||||||
token_type_ids = BertModelTest.ids_tensor(
|
hidden_size=self.hidden_size,
|
||||||
[self.batch_size, self.seq_length], self.type_vocab_size)
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
config = modeling.BertConfig(
|
model = modeling.BertModel(
|
||||||
vocab_size=self.vocab_size,
|
config=config,
|
||||||
hidden_size=self.hidden_size,
|
is_training=self.is_training,
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
input_ids=input_ids,
|
||||||
num_attention_heads=self.num_attention_heads,
|
input_mask=input_mask,
|
||||||
intermediate_size=self.intermediate_size,
|
token_type_ids=token_type_ids,
|
||||||
hidden_act=self.hidden_act,
|
scope=self.scope)
|
||||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
|
||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
|
||||||
type_vocab_size=self.type_vocab_size,
|
|
||||||
initializer_range=self.initializer_range)
|
|
||||||
|
|
||||||
model = modeling.BertModel(
|
outputs = {
|
||||||
config=config,
|
"embedding_output": model.get_embedding_output(),
|
||||||
is_training=self.is_training,
|
"sequence_output": model.get_sequence_output(),
|
||||||
input_ids=input_ids,
|
"pooled_output": model.get_pooled_output(),
|
||||||
input_mask=input_mask,
|
"all_encoder_layers": model.get_all_encoder_layers(),
|
||||||
token_type_ids=token_type_ids,
|
}
|
||||||
scope=self.scope)
|
return outputs
|
||||||
|
|
||||||
outputs = {
|
def check_output(self, result):
|
||||||
"embedding_output": model.get_embedding_output(),
|
self.parent.assertAllEqual(
|
||||||
"sequence_output": model.get_sequence_output(),
|
result["embedding_output"].shape,
|
||||||
"pooled_output": model.get_pooled_output(),
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
"all_encoder_layers": model.get_all_encoder_layers(),
|
|
||||||
}
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def check_output(self, result):
|
self.parent.assertAllEqual(
|
||||||
self.parent.assertAllEqual(
|
result["sequence_output"].shape,
|
||||||
result["embedding_output"].shape,
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
|
||||||
|
|
||||||
self.parent.assertAllEqual(
|
self.parent.assertAllEqual(result["pooled_output"].shape,
|
||||||
result["sequence_output"].shape,
|
[self.batch_size, self.hidden_size])
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
|
||||||
|
|
||||||
self.parent.assertAllEqual(result["pooled_output"].shape,
|
def test_default(self):
|
||||||
[self.batch_size, self.hidden_size])
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
|
|
||||||
def test_default(self):
|
def test_config_to_json_string(self):
|
||||||
self.run_tester(BertModelTest.BertModelTester(self))
|
config = modeling.BertConfig(vocab_size=99, hidden_size=37)
|
||||||
|
obj = json.loads(config.to_json_string())
|
||||||
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
|
self.assertEqual(obj["hidden_size"], 37)
|
||||||
|
|
||||||
def test_config_to_json_string(self):
|
def run_tester(self, tester):
|
||||||
config = modeling.BertConfig(vocab_size=99, hidden_size=37)
|
with self.test_session() as sess:
|
||||||
obj = json.loads(config.to_json_string())
|
ops = tester.create_model()
|
||||||
self.assertEqual(obj["vocab_size"], 99)
|
init_op = tf.group(tf.global_variables_initializer(),
|
||||||
self.assertEqual(obj["hidden_size"], 37)
|
tf.local_variables_initializer())
|
||||||
|
sess.run(init_op)
|
||||||
|
output_result = sess.run(ops)
|
||||||
|
tester.check_output(output_result)
|
||||||
|
|
||||||
def run_tester(self, tester):
|
self.assert_all_tensors_reachable(sess, [init_op, ops])
|
||||||
with self.test_session() as sess:
|
|
||||||
ops = tester.create_model()
|
|
||||||
init_op = tf.group(tf.global_variables_initializer(),
|
|
||||||
tf.local_variables_initializer())
|
|
||||||
sess.run(init_op)
|
|
||||||
output_result = sess.run(ops)
|
|
||||||
tester.check_output(output_result)
|
|
||||||
|
|
||||||
self.assert_all_tensors_reachable(sess, [init_op, ops])
|
@classmethod
|
||||||
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
|
if rng is None:
|
||||||
|
rng = random.Random()
|
||||||
|
|
||||||
@classmethod
|
total_dims = 1
|
||||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
for dim in shape:
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
total_dims *= dim
|
||||||
if rng is None:
|
|
||||||
rng = random.Random()
|
|
||||||
|
|
||||||
total_dims = 1
|
values = []
|
||||||
for dim in shape:
|
for _ in range(total_dims):
|
||||||
total_dims *= dim
|
values.append(rng.randint(0, vocab_size - 1))
|
||||||
|
|
||||||
values = []
|
return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
|
||||||
for _ in range(total_dims):
|
|
||||||
values.append(rng.randint(0, vocab_size - 1))
|
|
||||||
|
|
||||||
return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
|
def assert_all_tensors_reachable(self, sess, outputs):
|
||||||
|
"""Checks that all the tensors in the graph are reachable from outputs."""
|
||||||
|
graph = sess.graph
|
||||||
|
|
||||||
def assert_all_tensors_reachable(self, sess, outputs):
|
ignore_strings = [
|
||||||
"""Checks that all the tensors in the graph are reachable from outputs."""
|
"^.*/dilation_rate$",
|
||||||
graph = sess.graph
|
"^.*/Tensordot/concat$",
|
||||||
|
"^.*/Tensordot/concat/axis$",
|
||||||
|
"^testing/.*$",
|
||||||
|
]
|
||||||
|
|
||||||
ignore_strings = [
|
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
||||||
"^.*/dilation_rate$",
|
|
||||||
"^.*/Tensordot/concat$",
|
|
||||||
"^.*/Tensordot/concat/axis$",
|
|
||||||
"^testing/.*$",
|
|
||||||
]
|
|
||||||
|
|
||||||
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
unreachable = self.get_unreachable_ops(graph, outputs)
|
||||||
|
filtered_unreachable = []
|
||||||
|
for x in unreachable:
|
||||||
|
do_ignore = False
|
||||||
|
for r in ignore_regexes:
|
||||||
|
m = r.match(x.name)
|
||||||
|
if m is not None:
|
||||||
|
do_ignore = True
|
||||||
|
if do_ignore:
|
||||||
|
continue
|
||||||
|
filtered_unreachable.append(x)
|
||||||
|
unreachable = filtered_unreachable
|
||||||
|
|
||||||
unreachable = self.get_unreachable_ops(graph, outputs)
|
self.assertEqual(
|
||||||
filtered_unreachable = []
|
len(unreachable), 0, "The following ops are unreachable: %s" %
|
||||||
for x in unreachable:
|
(" ".join([x.name for x in unreachable])))
|
||||||
do_ignore = False
|
|
||||||
for r in ignore_regexes:
|
|
||||||
m = r.match(x.name)
|
|
||||||
if m is not None:
|
|
||||||
do_ignore = True
|
|
||||||
if do_ignore:
|
|
||||||
continue
|
|
||||||
filtered_unreachable.append(x)
|
|
||||||
unreachable = filtered_unreachable
|
|
||||||
|
|
||||||
self.assertEqual(
|
@classmethod
|
||||||
len(unreachable), 0, "The following ops are unreachable: %s" %
|
def get_unreachable_ops(cls, graph, outputs):
|
||||||
(" ".join([x.name for x in unreachable])))
|
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
||||||
|
outputs = cls.flatten_recursive(outputs)
|
||||||
|
output_to_op = collections.defaultdict(list)
|
||||||
|
op_to_all = collections.defaultdict(list)
|
||||||
|
assign_out_to_in = collections.defaultdict(list)
|
||||||
|
|
||||||
@classmethod
|
for op in graph.get_operations():
|
||||||
def get_unreachable_ops(cls, graph, outputs):
|
for x in op.inputs:
|
||||||
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
op_to_all[op.name].append(x.name)
|
||||||
outputs = cls.flatten_recursive(outputs)
|
for y in op.outputs:
|
||||||
output_to_op = collections.defaultdict(list)
|
output_to_op[y.name].append(op.name)
|
||||||
op_to_all = collections.defaultdict(list)
|
op_to_all[op.name].append(y.name)
|
||||||
assign_out_to_in = collections.defaultdict(list)
|
if str(op.type) == "Assign":
|
||||||
|
for y in op.outputs:
|
||||||
|
for x in op.inputs:
|
||||||
|
assign_out_to_in[y.name].append(x.name)
|
||||||
|
|
||||||
for op in graph.get_operations():
|
assign_groups = collections.defaultdict(list)
|
||||||
for x in op.inputs:
|
for out_name in assign_out_to_in.keys():
|
||||||
op_to_all[op.name].append(x.name)
|
name_group = assign_out_to_in[out_name]
|
||||||
for y in op.outputs:
|
for n1 in name_group:
|
||||||
output_to_op[y.name].append(op.name)
|
assign_groups[n1].append(out_name)
|
||||||
op_to_all[op.name].append(y.name)
|
for n2 in name_group:
|
||||||
if str(op.type) == "Assign":
|
if n1 != n2:
|
||||||
for y in op.outputs:
|
assign_groups[n1].append(n2)
|
||||||
for x in op.inputs:
|
|
||||||
assign_out_to_in[y.name].append(x.name)
|
|
||||||
|
|
||||||
assign_groups = collections.defaultdict(list)
|
seen_tensors = {}
|
||||||
for out_name in assign_out_to_in.keys():
|
stack = [x.name for x in outputs]
|
||||||
name_group = assign_out_to_in[out_name]
|
while stack:
|
||||||
for n1 in name_group:
|
name = stack.pop()
|
||||||
assign_groups[n1].append(out_name)
|
if name in seen_tensors:
|
||||||
for n2 in name_group:
|
continue
|
||||||
if n1 != n2:
|
seen_tensors[name] = True
|
||||||
assign_groups[n1].append(n2)
|
|
||||||
|
|
||||||
seen_tensors = {}
|
if name in output_to_op:
|
||||||
stack = [x.name for x in outputs]
|
for op_name in output_to_op[name]:
|
||||||
while stack:
|
if op_name in op_to_all:
|
||||||
name = stack.pop()
|
for input_name in op_to_all[op_name]:
|
||||||
if name in seen_tensors:
|
if input_name not in stack:
|
||||||
continue
|
stack.append(input_name)
|
||||||
seen_tensors[name] = True
|
|
||||||
|
|
||||||
if name in output_to_op:
|
expanded_names = []
|
||||||
for op_name in output_to_op[name]:
|
if name in assign_groups:
|
||||||
if op_name in op_to_all:
|
for assign_name in assign_groups[name]:
|
||||||
for input_name in op_to_all[op_name]:
|
expanded_names.append(assign_name)
|
||||||
if input_name not in stack:
|
|
||||||
stack.append(input_name)
|
|
||||||
|
|
||||||
expanded_names = []
|
for expanded_name in expanded_names:
|
||||||
if name in assign_groups:
|
if expanded_name not in stack:
|
||||||
for assign_name in assign_groups[name]:
|
stack.append(expanded_name)
|
||||||
expanded_names.append(assign_name)
|
|
||||||
|
|
||||||
for expanded_name in expanded_names:
|
unreachable_ops = []
|
||||||
if expanded_name not in stack:
|
for op in graph.get_operations():
|
||||||
stack.append(expanded_name)
|
is_unreachable = False
|
||||||
|
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
||||||
|
for name in all_names:
|
||||||
|
if name not in seen_tensors:
|
||||||
|
is_unreachable = True
|
||||||
|
if is_unreachable:
|
||||||
|
unreachable_ops.append(op)
|
||||||
|
return unreachable_ops
|
||||||
|
|
||||||
unreachable_ops = []
|
@classmethod
|
||||||
for op in graph.get_operations():
|
def flatten_recursive(cls, item):
|
||||||
is_unreachable = False
|
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
||||||
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
output = []
|
||||||
for name in all_names:
|
if isinstance(item, list):
|
||||||
if name not in seen_tensors:
|
output.extend(item)
|
||||||
is_unreachable = True
|
elif isinstance(item, tuple):
|
||||||
if is_unreachable:
|
output.extend(list(item))
|
||||||
unreachable_ops.append(op)
|
elif isinstance(item, dict):
|
||||||
return unreachable_ops
|
for (_, v) in six.iteritems(item):
|
||||||
|
output.append(v)
|
||||||
|
else:
|
||||||
|
return [item]
|
||||||
|
|
||||||
@classmethod
|
flat_output = []
|
||||||
def flatten_recursive(cls, item):
|
for x in output:
|
||||||
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
flat_output.extend(cls.flatten_recursive(x))
|
||||||
output = []
|
return flat_output
|
||||||
if isinstance(item, list):
|
|
||||||
output.extend(item)
|
|
||||||
elif isinstance(item, tuple):
|
|
||||||
output.extend(list(item))
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
for (_, v) in six.iteritems(item):
|
|
||||||
output.append(v)
|
|
||||||
else:
|
|
||||||
return [item]
|
|
||||||
|
|
||||||
flat_output = []
|
|
||||||
for x in output:
|
|
||||||
flat_output.extend(cls.flatten_recursive(x))
|
|
||||||
return flat_output
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|||||||
236
optimization.py
236
optimization.py
@@ -23,149 +23,149 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
|
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
|
||||||
"""Creates an optimizer training op."""
|
"""Creates an optimizer training op."""
|
||||||
global_step = tf.train.get_or_create_global_step()
|
global_step = tf.train.get_or_create_global_step()
|
||||||
|
|
||||||
learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
|
learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
|
||||||
|
|
||||||
# Implements linear decay of the learning rate.
|
# Implements linear decay of the learning rate.
|
||||||
learning_rate = tf.train.polynomial_decay(
|
learning_rate = tf.train.polynomial_decay(
|
||||||
learning_rate,
|
learning_rate,
|
||||||
global_step,
|
global_step,
|
||||||
num_train_steps,
|
num_train_steps,
|
||||||
end_learning_rate=0.0,
|
end_learning_rate=0.0,
|
||||||
power=1.0,
|
power=1.0,
|
||||||
cycle=False)
|
cycle=False)
|
||||||
|
|
||||||
# Implements linear warmup. I.e., if global_step < num_warmup_steps, the
|
# Implements linear warmup. I.e., if global_step < num_warmup_steps, the
|
||||||
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||||
if num_warmup_steps:
|
if num_warmup_steps:
|
||||||
global_steps_int = tf.cast(global_step, tf.int32)
|
global_steps_int = tf.cast(global_step, tf.int32)
|
||||||
warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
|
warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
|
||||||
|
|
||||||
global_steps_float = tf.cast(global_steps_int, tf.float32)
|
global_steps_float = tf.cast(global_steps_int, tf.float32)
|
||||||
warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
|
warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
|
||||||
|
|
||||||
warmup_percent_done = global_steps_float / warmup_steps_float
|
warmup_percent_done = global_steps_float / warmup_steps_float
|
||||||
warmup_learning_rate = init_lr * warmup_percent_done
|
warmup_learning_rate = init_lr * warmup_percent_done
|
||||||
|
|
||||||
is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
|
is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
|
||||||
learning_rate = (
|
learning_rate = (
|
||||||
(1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
|
(1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
|
||||||
|
|
||||||
# It is recommended that you use this optimizer for fine tuning, since this
|
# It is recommended that you use this optimizer for fine tuning, since this
|
||||||
# is how the model was trained (note that the Adam m/v variables are NOT
|
# is how the model was trained (note that the Adam m/v variables are NOT
|
||||||
# loaded from init_checkpoint.)
|
# loaded from init_checkpoint.)
|
||||||
optimizer = AdamWeightDecayOptimizer(
|
optimizer = AdamWeightDecayOptimizer(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
weight_decay_rate=0.01,
|
weight_decay_rate=0.01,
|
||||||
beta_1=0.9,
|
beta_1=0.9,
|
||||||
beta_2=0.999,
|
beta_2=0.999,
|
||||||
epsilon=1e-6,
|
epsilon=1e-6,
|
||||||
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
|
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
|
||||||
|
|
||||||
if use_tpu:
|
if use_tpu:
|
||||||
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
|
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
|
||||||
|
|
||||||
tvars = tf.trainable_variables()
|
tvars = tf.trainable_variables()
|
||||||
grads = tf.gradients(loss, tvars)
|
grads = tf.gradients(loss, tvars)
|
||||||
|
|
||||||
# This is how the model was pre-trained.
|
# This is how the model was pre-trained.
|
||||||
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
|
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
|
||||||
|
|
||||||
train_op = optimizer.apply_gradients(
|
train_op = optimizer.apply_gradients(
|
||||||
zip(grads, tvars), global_step=global_step)
|
zip(grads, tvars), global_step=global_step)
|
||||||
|
|
||||||
new_global_step = global_step + 1
|
new_global_step = global_step + 1
|
||||||
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
|
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
|
||||||
return train_op
|
return train_op
|
||||||
|
|
||||||
|
|
||||||
class AdamWeightDecayOptimizer(tf.train.Optimizer):
|
class AdamWeightDecayOptimizer(tf.train.Optimizer):
|
||||||
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
weight_decay_rate=0.0,
|
weight_decay_rate=0.0,
|
||||||
beta_1=0.9,
|
beta_1=0.9,
|
||||||
beta_2=0.999,
|
beta_2=0.999,
|
||||||
epsilon=1e-6,
|
epsilon=1e-6,
|
||||||
exclude_from_weight_decay=None,
|
exclude_from_weight_decay=None,
|
||||||
name="AdamWeightDecayOptimizer"):
|
name="AdamWeightDecayOptimizer"):
|
||||||
"""Constructs a AdamWeightDecayOptimizer."""
|
"""Constructs a AdamWeightDecayOptimizer."""
|
||||||
super(AdamWeightDecayOptimizer, self).__init__(False, name)
|
super(AdamWeightDecayOptimizer, self).__init__(False, name)
|
||||||
|
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.weight_decay_rate = weight_decay_rate
|
self.weight_decay_rate = weight_decay_rate
|
||||||
self.beta_1 = beta_1
|
self.beta_1 = beta_1
|
||||||
self.beta_2 = beta_2
|
self.beta_2 = beta_2
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.exclude_from_weight_decay = exclude_from_weight_decay
|
self.exclude_from_weight_decay = exclude_from_weight_decay
|
||||||
|
|
||||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
assignments = []
|
assignments = []
|
||||||
for (grad, param) in grads_and_vars:
|
for (grad, param) in grads_and_vars:
|
||||||
if grad is None or param is None:
|
if grad is None or param is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param_name = self._get_variable_name(param.name)
|
param_name = self._get_variable_name(param.name)
|
||||||
|
|
||||||
m = tf.get_variable(
|
m = tf.get_variable(
|
||||||
name=param_name + "/adam_m",
|
name=param_name + "/adam_m",
|
||||||
shape=param.shape.as_list(),
|
shape=param.shape.as_list(),
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
trainable=False,
|
trainable=False,
|
||||||
initializer=tf.zeros_initializer())
|
initializer=tf.zeros_initializer())
|
||||||
v = tf.get_variable(
|
v = tf.get_variable(
|
||||||
name=param_name + "/adam_v",
|
name=param_name + "/adam_v",
|
||||||
shape=param.shape.as_list(),
|
shape=param.shape.as_list(),
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
trainable=False,
|
trainable=False,
|
||||||
initializer=tf.zeros_initializer())
|
initializer=tf.zeros_initializer())
|
||||||
|
|
||||||
# Standard Adam update.
|
# Standard Adam update.
|
||||||
next_m = (
|
next_m = (
|
||||||
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
|
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
|
||||||
next_v = (
|
next_v = (
|
||||||
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
|
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
|
||||||
tf.square(grad)))
|
tf.square(grad)))
|
||||||
|
|
||||||
update = next_m / (tf.sqrt(next_v) + self.epsilon)
|
update = next_m / (tf.sqrt(next_v) + self.epsilon)
|
||||||
|
|
||||||
# Just adding the square of the weights to the loss function is *not*
|
# Just adding the square of the weights to the loss function is *not*
|
||||||
# the correct way of using L2 regularization/weight decay with Adam,
|
# the correct way of using L2 regularization/weight decay with Adam,
|
||||||
# since that will interact with the m and v parameters in strange ways.
|
# since that will interact with the m and v parameters in strange ways.
|
||||||
#
|
#
|
||||||
# Instead we want ot decay the weights in a manner that doesn't interact
|
# Instead we want ot decay the weights in a manner that doesn't interact
|
||||||
# with the m/v parameters. This is equivalent to adding the square
|
# with the m/v parameters. This is equivalent to adding the square
|
||||||
# of the weights to the loss with plain (non-momentum) SGD.
|
# of the weights to the loss with plain (non-momentum) SGD.
|
||||||
if self._do_use_weight_decay(param_name):
|
if self._do_use_weight_decay(param_name):
|
||||||
update += self.weight_decay_rate * param
|
update += self.weight_decay_rate * param
|
||||||
|
|
||||||
update_with_lr = self.learning_rate * update
|
update_with_lr = self.learning_rate * update
|
||||||
|
|
||||||
next_param = param - update_with_lr
|
next_param = param - update_with_lr
|
||||||
|
|
||||||
assignments.extend(
|
assignments.extend(
|
||||||
[param.assign(next_param),
|
[param.assign(next_param),
|
||||||
m.assign(next_m),
|
m.assign(next_m),
|
||||||
v.assign(next_v)])
|
v.assign(next_v)])
|
||||||
return tf.group(*assignments, name=name)
|
return tf.group(*assignments, name=name)
|
||||||
|
|
||||||
def _do_use_weight_decay(self, param_name):
|
def _do_use_weight_decay(self, param_name):
|
||||||
"""Whether to use L2 weight decay for `param_name`."""
|
"""Whether to use L2 weight decay for `param_name`."""
|
||||||
if not self.weight_decay_rate:
|
if not self.weight_decay_rate:
|
||||||
return False
|
return False
|
||||||
if self.exclude_from_weight_decay:
|
if self.exclude_from_weight_decay:
|
||||||
for r in self.exclude_from_weight_decay:
|
for r in self.exclude_from_weight_decay:
|
||||||
if re.search(r, param_name) is not None:
|
if re.search(r, param_name) is not None:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_variable_name(self, param_name):
|
def _get_variable_name(self, param_name):
|
||||||
"""Get the variable name from the tensor name."""
|
"""Get the variable name from the tensor name."""
|
||||||
m = re.match("^(.*):\\d+$", param_name)
|
m = re.match("^(.*):\\d+$", param_name)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
param_name = m.group(1)
|
param_name = m.group(1)
|
||||||
return param_name
|
return param_name
|
||||||
|
|||||||
@@ -22,27 +22,27 @@ import tensorflow as tf
|
|||||||
|
|
||||||
class OptimizationTest(tf.test.TestCase):
|
class OptimizationTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_adam(self):
|
def test_adam(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
w = tf.get_variable(
|
w = tf.get_variable(
|
||||||
"w",
|
"w",
|
||||||
shape=[3],
|
shape=[3],
|
||||||
initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
|
initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
|
||||||
x = tf.constant([0.4, 0.2, -0.5])
|
x = tf.constant([0.4, 0.2, -0.5])
|
||||||
loss = tf.reduce_mean(tf.square(x - w))
|
loss = tf.reduce_mean(tf.square(x - w))
|
||||||
tvars = tf.trainable_variables()
|
tvars = tf.trainable_variables()
|
||||||
grads = tf.gradients(loss, tvars)
|
grads = tf.gradients(loss, tvars)
|
||||||
global_step = tf.train.get_or_create_global_step()
|
global_step = tf.train.get_or_create_global_step()
|
||||||
optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
|
optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
|
||||||
train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
|
train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
|
||||||
init_op = tf.group(tf.global_variables_initializer(),
|
init_op = tf.group(tf.global_variables_initializer(),
|
||||||
tf.local_variables_initializer())
|
tf.local_variables_initializer())
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
sess.run(train_op)
|
sess.run(train_op)
|
||||||
w_np = sess.run(w)
|
w_np = sess.run(w)
|
||||||
self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
|
self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -109,217 +109,217 @@ flags.DEFINE_integer(
|
|||||||
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
||||||
num_train_steps, num_warmup_steps, use_tpu,
|
num_train_steps, num_warmup_steps, use_tpu,
|
||||||
use_one_hot_embeddings):
|
use_one_hot_embeddings):
|
||||||
"""Returns `model_fn` closure for TPUEstimator."""
|
"""Returns `model_fn` closure for TPUEstimator."""
|
||||||
|
|
||||||
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
"""The `model_fn` for TPUEstimator."""
|
"""The `model_fn` for TPUEstimator."""
|
||||||
|
|
||||||
tf.logging.info("*** Features ***")
|
tf.logging.info("*** Features ***")
|
||||||
for name in sorted(features.keys()):
|
for name in sorted(features.keys()):
|
||||||
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
||||||
|
|
||||||
input_ids = features["input_ids"]
|
input_ids = features["input_ids"]
|
||||||
input_mask = features["input_mask"]
|
input_mask = features["input_mask"]
|
||||||
segment_ids = features["segment_ids"]
|
segment_ids = features["segment_ids"]
|
||||||
masked_lm_positions = features["masked_lm_positions"]
|
masked_lm_positions = features["masked_lm_positions"]
|
||||||
masked_lm_ids = features["masked_lm_ids"]
|
masked_lm_ids = features["masked_lm_ids"]
|
||||||
masked_lm_weights = features["masked_lm_weights"]
|
masked_lm_weights = features["masked_lm_weights"]
|
||||||
next_sentence_labels = features["next_sentence_labels"]
|
next_sentence_labels = features["next_sentence_labels"]
|
||||||
|
|
||||||
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
||||||
|
|
||||||
model = modeling.BertModel(
|
model = modeling.BertModel(
|
||||||
config=bert_config,
|
config=bert_config,
|
||||||
is_training=is_training,
|
is_training=is_training,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
token_type_ids=segment_ids,
|
token_type_ids=segment_ids,
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||||
|
|
||||||
(masked_lm_loss,
|
(masked_lm_loss,
|
||||||
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
||||||
bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
||||||
masked_lm_positions, masked_lm_ids, masked_lm_weights)
|
masked_lm_positions, masked_lm_ids, masked_lm_weights)
|
||||||
|
|
||||||
(next_sentence_loss, next_sentence_example_loss,
|
(next_sentence_loss, next_sentence_example_loss,
|
||||||
next_sentence_log_probs) = get_next_sentence_output(
|
next_sentence_log_probs) = get_next_sentence_output(
|
||||||
bert_config, model.get_pooled_output(), next_sentence_labels)
|
bert_config, model.get_pooled_output(), next_sentence_labels)
|
||||||
|
|
||||||
total_loss = masked_lm_loss + next_sentence_loss
|
total_loss = masked_lm_loss + next_sentence_loss
|
||||||
|
|
||||||
tvars = tf.trainable_variables()
|
tvars = tf.trainable_variables()
|
||||||
|
|
||||||
initialized_variable_names = {}
|
initialized_variable_names = {}
|
||||||
scaffold_fn = None
|
scaffold_fn = None
|
||||||
if init_checkpoint:
|
if init_checkpoint:
|
||||||
(assignment_map,
|
(assignment_map,
|
||||||
initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
|
initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
|
||||||
tvars, init_checkpoint)
|
tvars, init_checkpoint)
|
||||||
if use_tpu:
|
if use_tpu:
|
||||||
|
|
||||||
def tpu_scaffold():
|
def tpu_scaffold():
|
||||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
return tf.train.Scaffold()
|
return tf.train.Scaffold()
|
||||||
|
|
||||||
scaffold_fn = tpu_scaffold
|
scaffold_fn = tpu_scaffold
|
||||||
else:
|
else:
|
||||||
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
|
|
||||||
tf.logging.info("**** Trainable Variables ****")
|
tf.logging.info("**** Trainable Variables ****")
|
||||||
for var in tvars:
|
for var in tvars:
|
||||||
init_string = ""
|
init_string = ""
|
||||||
if var.name in initialized_variable_names:
|
if var.name in initialized_variable_names:
|
||||||
init_string = ", *INIT_FROM_CKPT*"
|
init_string = ", *INIT_FROM_CKPT*"
|
||||||
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
||||||
init_string)
|
init_string)
|
||||||
|
|
||||||
output_spec = None
|
output_spec = None
|
||||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||||
train_op = optimization.create_optimizer(
|
train_op = optimization.create_optimizer(
|
||||||
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
||||||
|
|
||||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
loss=total_loss,
|
loss=total_loss,
|
||||||
train_op=train_op,
|
train_op=train_op,
|
||||||
scaffold_fn=scaffold_fn)
|
scaffold_fn=scaffold_fn)
|
||||||
elif mode == tf.estimator.ModeKeys.EVAL:
|
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||||
|
|
||||||
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||||
masked_lm_weights, next_sentence_example_loss,
|
masked_lm_weights, next_sentence_example_loss,
|
||||||
next_sentence_log_probs, next_sentence_labels):
|
next_sentence_log_probs, next_sentence_labels):
|
||||||
"""Computes the loss and accuracy of the model."""
|
"""Computes the loss and accuracy of the model."""
|
||||||
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
|
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
|
||||||
[-1, masked_lm_log_probs.shape[-1]])
|
[-1, masked_lm_log_probs.shape[-1]])
|
||||||
masked_lm_predictions = tf.argmax(
|
masked_lm_predictions = tf.argmax(
|
||||||
masked_lm_log_probs, axis=-1, output_type=tf.int32)
|
masked_lm_log_probs, axis=-1, output_type=tf.int32)
|
||||||
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
|
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
|
||||||
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
|
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
|
||||||
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
|
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
|
||||||
masked_lm_accuracy = tf.metrics.accuracy(
|
masked_lm_accuracy = tf.metrics.accuracy(
|
||||||
labels=masked_lm_ids,
|
labels=masked_lm_ids,
|
||||||
predictions=masked_lm_predictions,
|
predictions=masked_lm_predictions,
|
||||||
weights=masked_lm_weights)
|
weights=masked_lm_weights)
|
||||||
masked_lm_mean_loss = tf.metrics.mean(
|
masked_lm_mean_loss = tf.metrics.mean(
|
||||||
values=masked_lm_example_loss, weights=masked_lm_weights)
|
values=masked_lm_example_loss, weights=masked_lm_weights)
|
||||||
|
|
||||||
next_sentence_log_probs = tf.reshape(
|
next_sentence_log_probs = tf.reshape(
|
||||||
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
|
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
|
||||||
next_sentence_predictions = tf.argmax(
|
next_sentence_predictions = tf.argmax(
|
||||||
next_sentence_log_probs, axis=-1, output_type=tf.int32)
|
next_sentence_log_probs, axis=-1, output_type=tf.int32)
|
||||||
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
|
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
|
||||||
next_sentence_accuracy = tf.metrics.accuracy(
|
next_sentence_accuracy = tf.metrics.accuracy(
|
||||||
labels=next_sentence_labels, predictions=next_sentence_predictions)
|
labels=next_sentence_labels, predictions=next_sentence_predictions)
|
||||||
next_sentence_mean_loss = tf.metrics.mean(
|
next_sentence_mean_loss = tf.metrics.mean(
|
||||||
values=next_sentence_example_loss)
|
values=next_sentence_example_loss)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"masked_lm_accuracy": masked_lm_accuracy,
|
"masked_lm_accuracy": masked_lm_accuracy,
|
||||||
"masked_lm_loss": masked_lm_mean_loss,
|
"masked_lm_loss": masked_lm_mean_loss,
|
||||||
"next_sentence_accuracy": next_sentence_accuracy,
|
"next_sentence_accuracy": next_sentence_accuracy,
|
||||||
"next_sentence_loss": next_sentence_mean_loss,
|
"next_sentence_loss": next_sentence_mean_loss,
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_metrics = (metric_fn, [
|
eval_metrics = (metric_fn, [
|
||||||
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||||
masked_lm_weights, next_sentence_example_loss,
|
masked_lm_weights, next_sentence_example_loss,
|
||||||
next_sentence_log_probs, next_sentence_labels
|
next_sentence_log_probs, next_sentence_labels
|
||||||
])
|
])
|
||||||
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
loss=total_loss,
|
loss=total_loss,
|
||||||
eval_metrics=eval_metrics,
|
eval_metrics=eval_metrics,
|
||||||
scaffold_fn=scaffold_fn)
|
scaffold_fn=scaffold_fn)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
|
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
|
||||||
|
|
||||||
return output_spec
|
return output_spec
|
||||||
|
|
||||||
return model_fn
|
return model_fn
|
||||||
|
|
||||||
|
|
||||||
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
|
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
|
||||||
label_ids, label_weights):
|
label_ids, label_weights):
|
||||||
"""Get loss and log probs for the masked LM."""
|
"""Get loss and log probs for the masked LM."""
|
||||||
input_tensor = gather_indexes(input_tensor, positions)
|
input_tensor = gather_indexes(input_tensor, positions)
|
||||||
|
|
||||||
with tf.variable_scope("cls/predictions"):
|
with tf.variable_scope("cls/predictions"):
|
||||||
# We apply one more non-linear transformation before the output layer.
|
# We apply one more non-linear transformation before the output layer.
|
||||||
# This matrix is not used after pre-training.
|
# This matrix is not used after pre-training.
|
||||||
with tf.variable_scope("transform"):
|
with tf.variable_scope("transform"):
|
||||||
input_tensor = tf.layers.dense(
|
input_tensor = tf.layers.dense(
|
||||||
input_tensor,
|
input_tensor,
|
||||||
units=bert_config.hidden_size,
|
units=bert_config.hidden_size,
|
||||||
activation=modeling.get_activation(bert_config.hidden_act),
|
activation=modeling.get_activation(bert_config.hidden_act),
|
||||||
kernel_initializer=modeling.create_initializer(
|
kernel_initializer=modeling.create_initializer(
|
||||||
bert_config.initializer_range))
|
bert_config.initializer_range))
|
||||||
input_tensor = modeling.layer_norm(input_tensor)
|
input_tensor = modeling.layer_norm(input_tensor)
|
||||||
|
|
||||||
# The output weights are the same as the input embeddings, but there is
|
# The output weights are the same as the input embeddings, but there is
|
||||||
# an output-only bias for each token.
|
# an output-only bias for each token.
|
||||||
output_bias = tf.get_variable(
|
output_bias = tf.get_variable(
|
||||||
"output_bias",
|
"output_bias",
|
||||||
shape=[bert_config.vocab_size],
|
shape=[bert_config.vocab_size],
|
||||||
initializer=tf.zeros_initializer())
|
initializer=tf.zeros_initializer())
|
||||||
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||||
logits = tf.nn.bias_add(logits, output_bias)
|
logits = tf.nn.bias_add(logits, output_bias)
|
||||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||||
|
|
||||||
label_ids = tf.reshape(label_ids, [-1])
|
label_ids = tf.reshape(label_ids, [-1])
|
||||||
label_weights = tf.reshape(label_weights, [-1])
|
label_weights = tf.reshape(label_weights, [-1])
|
||||||
|
|
||||||
one_hot_labels = tf.one_hot(
|
one_hot_labels = tf.one_hot(
|
||||||
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
|
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
|
||||||
|
|
||||||
# The `positions` tensor might be zero-padded (if the sequence is too
|
# The `positions` tensor might be zero-padded (if the sequence is too
|
||||||
# short to have the maximum number of predictions). The `label_weights`
|
# short to have the maximum number of predictions). The `label_weights`
|
||||||
# tensor has a value of 1.0 for every real prediction and 0.0 for the
|
# tensor has a value of 1.0 for every real prediction and 0.0 for the
|
||||||
# padding predictions.
|
# padding predictions.
|
||||||
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
|
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
|
||||||
numerator = tf.reduce_sum(label_weights * per_example_loss)
|
numerator = tf.reduce_sum(label_weights * per_example_loss)
|
||||||
denominator = tf.reduce_sum(label_weights) + 1e-5
|
denominator = tf.reduce_sum(label_weights) + 1e-5
|
||||||
loss = numerator / denominator
|
loss = numerator / denominator
|
||||||
|
|
||||||
return (loss, per_example_loss, log_probs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_next_sentence_output(bert_config, input_tensor, labels):
|
|
||||||
"""Get loss and log probs for the next sentence prediction."""
|
|
||||||
|
|
||||||
# Simple binary classification. Note that 0 is "next sentence" and 1 is
|
|
||||||
# "random sentence". This weight matrix is not used after pre-training.
|
|
||||||
with tf.variable_scope("cls/seq_relationship"):
|
|
||||||
output_weights = tf.get_variable(
|
|
||||||
"output_weights",
|
|
||||||
shape=[2, bert_config.hidden_size],
|
|
||||||
initializer=modeling.create_initializer(bert_config.initializer_range))
|
|
||||||
output_bias = tf.get_variable(
|
|
||||||
"output_bias", shape=[2], initializer=tf.zeros_initializer())
|
|
||||||
|
|
||||||
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
|
||||||
logits = tf.nn.bias_add(logits, output_bias)
|
|
||||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
|
||||||
labels = tf.reshape(labels, [-1])
|
|
||||||
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
|
|
||||||
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
|
||||||
loss = tf.reduce_mean(per_example_loss)
|
|
||||||
return (loss, per_example_loss, log_probs)
|
return (loss, per_example_loss, log_probs)
|
||||||
|
|
||||||
|
|
||||||
def gather_indexes(sequence_tensor, positions):
|
def get_next_sentence_output(bert_config, input_tensor, labels):
|
||||||
"""Gathers the vectors at the specific positions over a minibatch."""
|
"""Get loss and log probs for the next sentence prediction."""
|
||||||
sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
|
|
||||||
batch_size = sequence_shape[0]
|
|
||||||
seq_length = sequence_shape[1]
|
|
||||||
width = sequence_shape[2]
|
|
||||||
|
|
||||||
flat_offsets = tf.reshape(
|
# Simple binary classification. Note that 0 is "next sentence" and 1 is
|
||||||
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
|
# "random sentence". This weight matrix is not used after pre-training.
|
||||||
flat_positions = tf.reshape(positions + flat_offsets, [-1])
|
with tf.variable_scope("cls/seq_relationship"):
|
||||||
flat_sequence_tensor = tf.reshape(sequence_tensor,
|
output_weights = tf.get_variable(
|
||||||
[batch_size * seq_length, width])
|
"output_weights",
|
||||||
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
|
shape=[2, bert_config.hidden_size],
|
||||||
return output_tensor
|
initializer=modeling.create_initializer(bert_config.initializer_range))
|
||||||
|
output_bias = tf.get_variable(
|
||||||
|
"output_bias", shape=[2], initializer=tf.zeros_initializer())
|
||||||
|
|
||||||
|
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||||
|
logits = tf.nn.bias_add(logits, output_bias)
|
||||||
|
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||||
|
labels = tf.reshape(labels, [-1])
|
||||||
|
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
|
||||||
|
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
||||||
|
loss = tf.reduce_mean(per_example_loss)
|
||||||
|
return (loss, per_example_loss, log_probs)
|
||||||
|
|
||||||
|
|
||||||
|
def gather_indexes(sequence_tensor, positions):
|
||||||
|
"""Gathers the vectors at the specific positions over a minibatch."""
|
||||||
|
sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
|
||||||
|
batch_size = sequence_shape[0]
|
||||||
|
seq_length = sequence_shape[1]
|
||||||
|
width = sequence_shape[2]
|
||||||
|
|
||||||
|
flat_offsets = tf.reshape(
|
||||||
|
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
|
||||||
|
flat_positions = tf.reshape(positions + flat_offsets, [-1])
|
||||||
|
flat_sequence_tensor = tf.reshape(sequence_tensor,
|
||||||
|
[batch_size * seq_length, width])
|
||||||
|
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
def input_fn_builder(input_files,
|
def input_fn_builder(input_files,
|
||||||
@@ -327,168 +327,168 @@ def input_fn_builder(input_files,
|
|||||||
max_predictions_per_seq,
|
max_predictions_per_seq,
|
||||||
is_training,
|
is_training,
|
||||||
num_cpu_threads=4):
|
num_cpu_threads=4):
|
||||||
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||||
|
|
||||||
def input_fn(params):
|
def input_fn(params):
|
||||||
"""The actual input function."""
|
"""The actual input function."""
|
||||||
batch_size = params["batch_size"]
|
batch_size = params["batch_size"]
|
||||||
|
|
||||||
name_to_features = {
|
name_to_features = {
|
||||||
"input_ids":
|
"input_ids":
|
||||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||||
"input_mask":
|
"input_mask":
|
||||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||||
"segment_ids":
|
"segment_ids":
|
||||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||||
"masked_lm_positions":
|
"masked_lm_positions":
|
||||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||||
"masked_lm_ids":
|
"masked_lm_ids":
|
||||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||||
"masked_lm_weights":
|
"masked_lm_weights":
|
||||||
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||||
"next_sentence_labels":
|
"next_sentence_labels":
|
||||||
tf.FixedLenFeature([1], tf.int64),
|
tf.FixedLenFeature([1], tf.int64),
|
||||||
}
|
}
|
||||||
|
|
||||||
# For training, we want a lot of parallel reading and shuffling.
|
# For training, we want a lot of parallel reading and shuffling.
|
||||||
# For eval, we want no shuffling and parallel reading doesn't matter.
|
# For eval, we want no shuffling and parallel reading doesn't matter.
|
||||||
if is_training:
|
if is_training:
|
||||||
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
|
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
|
||||||
d = d.repeat()
|
d = d.repeat()
|
||||||
d = d.shuffle(buffer_size=len(input_files))
|
d = d.shuffle(buffer_size=len(input_files))
|
||||||
|
|
||||||
# `cycle_length` is the number of parallel files that get read.
|
# `cycle_length` is the number of parallel files that get read.
|
||||||
cycle_length = min(num_cpu_threads, len(input_files))
|
cycle_length = min(num_cpu_threads, len(input_files))
|
||||||
|
|
||||||
# `sloppy` mode means that the interleaving is not exact. This adds
|
# `sloppy` mode means that the interleaving is not exact. This adds
|
||||||
# even more randomness to the training pipeline.
|
# even more randomness to the training pipeline.
|
||||||
d = d.apply(
|
d = d.apply(
|
||||||
tf.contrib.data.parallel_interleave(
|
tf.contrib.data.parallel_interleave(
|
||||||
tf.data.TFRecordDataset,
|
tf.data.TFRecordDataset,
|
||||||
sloppy=is_training,
|
sloppy=is_training,
|
||||||
cycle_length=cycle_length))
|
cycle_length=cycle_length))
|
||||||
d = d.shuffle(buffer_size=100)
|
d = d.shuffle(buffer_size=100)
|
||||||
else:
|
else:
|
||||||
d = tf.data.TFRecordDataset(input_files)
|
d = tf.data.TFRecordDataset(input_files)
|
||||||
# Since we evaluate for a fixed number of steps we don't want to encounter
|
# Since we evaluate for a fixed number of steps we don't want to encounter
|
||||||
# out-of-range exceptions.
|
# out-of-range exceptions.
|
||||||
d = d.repeat()
|
d = d.repeat()
|
||||||
|
|
||||||
# We must `drop_remainder` on training because the TPU requires fixed
|
# We must `drop_remainder` on training because the TPU requires fixed
|
||||||
# size dimensions. For eval, we assume we are evaling on the CPU or GPU
|
# size dimensions. For eval, we assume we are evaling on the CPU or GPU
|
||||||
# and we *don"t* want to drop the remainder, otherwise we wont cover
|
# and we *don"t* want to drop the remainder, otherwise we wont cover
|
||||||
# every sample.
|
# every sample.
|
||||||
d = d.apply(
|
d = d.apply(
|
||||||
tf.contrib.data.map_and_batch(
|
tf.contrib.data.map_and_batch(
|
||||||
lambda record: _decode_record(record, name_to_features),
|
lambda record: _decode_record(record, name_to_features),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_parallel_batches=num_cpu_threads,
|
num_parallel_batches=num_cpu_threads,
|
||||||
drop_remainder=True))
|
drop_remainder=True))
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return input_fn
|
return input_fn
|
||||||
|
|
||||||
|
|
||||||
def _decode_record(record, name_to_features):
|
def _decode_record(record, name_to_features):
|
||||||
"""Decodes a record to a TensorFlow example."""
|
"""Decodes a record to a TensorFlow example."""
|
||||||
example = tf.parse_single_example(record, name_to_features)
|
example = tf.parse_single_example(record, name_to_features)
|
||||||
|
|
||||||
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||||
# So cast all int64 to int32.
|
# So cast all int64 to int32.
|
||||||
for name in list(example.keys()):
|
for name in list(example.keys()):
|
||||||
t = example[name]
|
t = example[name]
|
||||||
if t.dtype == tf.int64:
|
if t.dtype == tf.int64:
|
||||||
t = tf.to_int32(t)
|
t = tf.to_int32(t)
|
||||||
example[name] = t
|
example[name] = t
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
if not FLAGS.do_train and not FLAGS.do_eval:
|
if not FLAGS.do_train and not FLAGS.do_eval:
|
||||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||||
|
|
||||||
tf.gfile.MakeDirs(FLAGS.output_dir)
|
tf.gfile.MakeDirs(FLAGS.output_dir)
|
||||||
|
|
||||||
input_files = []
|
input_files = []
|
||||||
for input_pattern in FLAGS.input_file.split(","):
|
for input_pattern in FLAGS.input_file.split(","):
|
||||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||||
|
|
||||||
tf.logging.info("*** Input Files ***")
|
tf.logging.info("*** Input Files ***")
|
||||||
for input_file in input_files:
|
for input_file in input_files:
|
||||||
tf.logging.info(" %s" % input_file)
|
tf.logging.info(" %s" % input_file)
|
||||||
|
|
||||||
tpu_cluster_resolver = None
|
tpu_cluster_resolver = None
|
||||||
if FLAGS.use_tpu and FLAGS.tpu_name:
|
if FLAGS.use_tpu and FLAGS.tpu_name:
|
||||||
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||||
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
||||||
|
|
||||||
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||||
run_config = tf.contrib.tpu.RunConfig(
|
run_config = tf.contrib.tpu.RunConfig(
|
||||||
cluster=tpu_cluster_resolver,
|
cluster=tpu_cluster_resolver,
|
||||||
master=FLAGS.master,
|
master=FLAGS.master,
|
||||||
model_dir=FLAGS.output_dir,
|
model_dir=FLAGS.output_dir,
|
||||||
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
||||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||||
iterations_per_loop=FLAGS.iterations_per_loop,
|
iterations_per_loop=FLAGS.iterations_per_loop,
|
||||||
num_shards=FLAGS.num_tpu_cores,
|
num_shards=FLAGS.num_tpu_cores,
|
||||||
per_host_input_for_training=is_per_host))
|
per_host_input_for_training=is_per_host))
|
||||||
|
|
||||||
model_fn = model_fn_builder(
|
model_fn = model_fn_builder(
|
||||||
bert_config=bert_config,
|
bert_config=bert_config,
|
||||||
init_checkpoint=FLAGS.init_checkpoint,
|
init_checkpoint=FLAGS.init_checkpoint,
|
||||||
learning_rate=FLAGS.learning_rate,
|
learning_rate=FLAGS.learning_rate,
|
||||||
num_train_steps=FLAGS.num_train_steps,
|
num_train_steps=FLAGS.num_train_steps,
|
||||||
num_warmup_steps=FLAGS.num_warmup_steps,
|
num_warmup_steps=FLAGS.num_warmup_steps,
|
||||||
use_tpu=FLAGS.use_tpu,
|
use_tpu=FLAGS.use_tpu,
|
||||||
use_one_hot_embeddings=FLAGS.use_tpu)
|
use_one_hot_embeddings=FLAGS.use_tpu)
|
||||||
|
|
||||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
# or GPU.
|
# or GPU.
|
||||||
estimator = tf.contrib.tpu.TPUEstimator(
|
estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
use_tpu=FLAGS.use_tpu,
|
use_tpu=FLAGS.use_tpu,
|
||||||
model_fn=model_fn,
|
model_fn=model_fn,
|
||||||
config=run_config,
|
config=run_config,
|
||||||
train_batch_size=FLAGS.train_batch_size,
|
train_batch_size=FLAGS.train_batch_size,
|
||||||
eval_batch_size=FLAGS.eval_batch_size)
|
eval_batch_size=FLAGS.eval_batch_size)
|
||||||
|
|
||||||
if FLAGS.do_train:
|
if FLAGS.do_train:
|
||||||
tf.logging.info("***** Running training *****")
|
tf.logging.info("***** Running training *****")
|
||||||
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
||||||
train_input_fn = input_fn_builder(
|
train_input_fn = input_fn_builder(
|
||||||
input_files=input_files,
|
input_files=input_files,
|
||||||
max_seq_length=FLAGS.max_seq_length,
|
max_seq_length=FLAGS.max_seq_length,
|
||||||
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||||
is_training=True)
|
is_training=True)
|
||||||
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
|
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
|
||||||
|
|
||||||
if FLAGS.do_eval:
|
if FLAGS.do_eval:
|
||||||
tf.logging.info("***** Running evaluation *****")
|
tf.logging.info("***** Running evaluation *****")
|
||||||
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
||||||
|
|
||||||
eval_input_fn = input_fn_builder(
|
eval_input_fn = input_fn_builder(
|
||||||
input_files=input_files,
|
input_files=input_files,
|
||||||
max_seq_length=FLAGS.max_seq_length,
|
max_seq_length=FLAGS.max_seq_length,
|
||||||
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||||
is_training=False)
|
is_training=False)
|
||||||
|
|
||||||
result = estimator.evaluate(
|
result = estimator.evaluate(
|
||||||
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
|
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
|
||||||
|
|
||||||
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
||||||
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
||||||
tf.logging.info("***** Eval results *****")
|
tf.logging.info("***** Eval results *****")
|
||||||
for key in sorted(result.keys()):
|
for key in sorted(result.keys()):
|
||||||
tf.logging.info(" %s = %s", key, str(result[key]))
|
tf.logging.info(" %s = %s", key, str(result[key]))
|
||||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
flags.mark_flag_as_required("input_file")
|
flags.mark_flag_as_required("input_file")
|
||||||
flags.mark_flag_as_required("bert_config_file")
|
flags.mark_flag_as_required("bert_config_file")
|
||||||
flags.mark_flag_as_required("output_dir")
|
flags.mark_flag_as_required("output_dir")
|
||||||
tf.app.run()
|
tf.app.run()
|
||||||
|
|||||||
1606
run_squad.py
1606
run_squad.py
File diff suppressed because it is too large
Load Diff
416
tokenization.py
416
tokenization.py
@@ -25,268 +25,268 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
def convert_to_unicode(text):
|
def convert_to_unicode(text):
|
||||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||||
if six.PY3:
|
if six.PY3:
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
return text
|
return text
|
||||||
elif isinstance(text, bytes):
|
elif isinstance(text, bytes):
|
||||||
return text.decode("utf-8", "ignore")
|
return text.decode("utf-8", "ignore")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
elif six.PY2:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
elif isinstance(text, unicode):
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
raise ValueError("Not running on Python2 or Python 3?")
|
||||||
elif six.PY2:
|
|
||||||
if isinstance(text, str):
|
|
||||||
return text.decode("utf-8", "ignore")
|
|
||||||
elif isinstance(text, unicode):
|
|
||||||
return text
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
|
||||||
else:
|
|
||||||
raise ValueError("Not running on Python2 or Python 3?")
|
|
||||||
|
|
||||||
|
|
||||||
def printable_text(text):
|
def printable_text(text):
|
||||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||||
|
|
||||||
# These functions want `str` for both Python2 and Python3, but in one case
|
# These functions want `str` for both Python2 and Python3, but in one case
|
||||||
# it's a Unicode string and in the other it's a byte string.
|
# it's a Unicode string and in the other it's a byte string.
|
||||||
if six.PY3:
|
if six.PY3:
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
return text
|
return text
|
||||||
elif isinstance(text, bytes):
|
elif isinstance(text, bytes):
|
||||||
return text.decode("utf-8", "ignore")
|
return text.decode("utf-8", "ignore")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
elif six.PY2:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, unicode):
|
||||||
|
return text.encode("utf-8")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
raise ValueError("Not running on Python2 or Python 3?")
|
||||||
elif six.PY2:
|
|
||||||
if isinstance(text, str):
|
|
||||||
return text
|
|
||||||
elif isinstance(text, unicode):
|
|
||||||
return text.encode("utf-8")
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
|
||||||
else:
|
|
||||||
raise ValueError("Not running on Python2 or Python 3?")
|
|
||||||
|
|
||||||
|
|
||||||
def load_vocab(vocab_file):
|
def load_vocab(vocab_file):
|
||||||
"""Loads a vocabulary file into a dictionary."""
|
"""Loads a vocabulary file into a dictionary."""
|
||||||
vocab = collections.OrderedDict()
|
vocab = collections.OrderedDict()
|
||||||
index = 0
|
index = 0
|
||||||
with tf.gfile.GFile(vocab_file, "r") as reader:
|
with tf.gfile.GFile(vocab_file, "r") as reader:
|
||||||
while True:
|
while True:
|
||||||
token = convert_to_unicode(reader.readline())
|
token = convert_to_unicode(reader.readline())
|
||||||
if not token:
|
if not token:
|
||||||
break
|
break
|
||||||
token = token.strip()
|
token = token.strip()
|
||||||
vocab[token] = index
|
vocab[token] = index
|
||||||
index += 1
|
index += 1
|
||||||
return vocab
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
def convert_tokens_to_ids(vocab, tokens):
|
def convert_tokens_to_ids(vocab, tokens):
|
||||||
"""Converts a sequence of tokens into ids using the vocab."""
|
"""Converts a sequence of tokens into ids using the vocab."""
|
||||||
ids = []
|
ids = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
ids.append(vocab[token])
|
ids.append(vocab[token])
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
def whitespace_tokenize(text):
|
def whitespace_tokenize(text):
|
||||||
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return []
|
||||||
tokens = text.split()
|
tokens = text.split()
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
class FullTokenizer(object):
|
class FullTokenizer(object):
|
||||||
"""Runs end-to-end tokenziation."""
|
"""Runs end-to-end tokenziation."""
|
||||||
|
|
||||||
def __init__(self, vocab_file, do_lower_case=True):
|
def __init__(self, vocab_file, do_lower_case=True):
|
||||||
self.vocab = load_vocab(vocab_file)
|
self.vocab = load_vocab(vocab_file)
|
||||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
for token in self.basic_tokenizer.tokenize(text):
|
for token in self.basic_tokenizer.tokenize(text):
|
||||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||||
split_tokens.append(sub_token)
|
split_tokens.append(sub_token)
|
||||||
|
|
||||||
return split_tokens
|
return split_tokens
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
return convert_tokens_to_ids(self.vocab, tokens)
|
return convert_tokens_to_ids(self.vocab, tokens)
|
||||||
|
|
||||||
|
|
||||||
class BasicTokenizer(object):
|
class BasicTokenizer(object):
|
||||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||||
|
|
||||||
def __init__(self, do_lower_case=True):
|
def __init__(self, do_lower_case=True):
|
||||||
"""Constructs a BasicTokenizer.
|
"""Constructs a BasicTokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
do_lower_case: Whether to lower case the input.
|
do_lower_case: Whether to lower case the input.
|
||||||
"""
|
"""
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
"""Tokenizes a piece of text."""
|
"""Tokenizes a piece of text."""
|
||||||
text = convert_to_unicode(text)
|
text = convert_to_unicode(text)
|
||||||
text = self._clean_text(text)
|
text = self._clean_text(text)
|
||||||
orig_tokens = whitespace_tokenize(text)
|
orig_tokens = whitespace_tokenize(text)
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
for token in orig_tokens:
|
for token in orig_tokens:
|
||||||
if self.do_lower_case:
|
if self.do_lower_case:
|
||||||
token = token.lower()
|
token = token.lower()
|
||||||
token = self._run_strip_accents(token)
|
token = self._run_strip_accents(token)
|
||||||
split_tokens.extend(self._run_split_on_punc(token))
|
split_tokens.extend(self._run_split_on_punc(token))
|
||||||
|
|
||||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||||
return output_tokens
|
return output_tokens
|
||||||
|
|
||||||
def _run_strip_accents(self, text):
|
def _run_strip_accents(self, text):
|
||||||
"""Strips accents from a piece of text."""
|
"""Strips accents from a piece of text."""
|
||||||
text = unicodedata.normalize("NFD", text)
|
text = unicodedata.normalize("NFD", text)
|
||||||
output = []
|
output = []
|
||||||
for char in text:
|
for char in text:
|
||||||
cat = unicodedata.category(char)
|
cat = unicodedata.category(char)
|
||||||
if cat == "Mn":
|
if cat == "Mn":
|
||||||
continue
|
continue
|
||||||
output.append(char)
|
output.append(char)
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
def _run_split_on_punc(self, text):
|
def _run_split_on_punc(self, text):
|
||||||
"""Splits punctuation on a piece of text."""
|
"""Splits punctuation on a piece of text."""
|
||||||
chars = list(text)
|
chars = list(text)
|
||||||
i = 0
|
i = 0
|
||||||
start_new_word = True
|
|
||||||
output = []
|
|
||||||
while i < len(chars):
|
|
||||||
char = chars[i]
|
|
||||||
if _is_punctuation(char):
|
|
||||||
output.append([char])
|
|
||||||
start_new_word = True
|
start_new_word = True
|
||||||
else:
|
output = []
|
||||||
if start_new_word:
|
while i < len(chars):
|
||||||
output.append([])
|
char = chars[i]
|
||||||
start_new_word = False
|
if _is_punctuation(char):
|
||||||
output[-1].append(char)
|
output.append([char])
|
||||||
i += 1
|
start_new_word = True
|
||||||
|
else:
|
||||||
|
if start_new_word:
|
||||||
|
output.append([])
|
||||||
|
start_new_word = False
|
||||||
|
output[-1].append(char)
|
||||||
|
i += 1
|
||||||
|
|
||||||
return ["".join(x) for x in output]
|
return ["".join(x) for x in output]
|
||||||
|
|
||||||
def _clean_text(self, text):
|
def _clean_text(self, text):
|
||||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||||
output = []
|
output = []
|
||||||
for char in text:
|
for char in text:
|
||||||
cp = ord(char)
|
cp = ord(char)
|
||||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||||
continue
|
continue
|
||||||
if _is_whitespace(char):
|
if _is_whitespace(char):
|
||||||
output.append(" ")
|
output.append(" ")
|
||||||
else:
|
else:
|
||||||
output.append(char)
|
output.append(char)
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
class WordpieceTokenizer(object):
|
class WordpieceTokenizer(object):
|
||||||
"""Runs WordPiece tokenziation."""
|
"""Runs WordPiece tokenziation."""
|
||||||
|
|
||||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.unk_token = unk_token
|
self.unk_token = unk_token
|
||||||
self.max_input_chars_per_word = max_input_chars_per_word
|
self.max_input_chars_per_word = max_input_chars_per_word
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
"""Tokenizes a piece of text into its word pieces.
|
"""Tokenizes a piece of text into its word pieces.
|
||||||
|
|
||||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||||
using the given vocabulary.
|
using the given vocabulary.
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
input = "unaffable"
|
input = "unaffable"
|
||||||
output = ["un", "##aff", "##able"]
|
output = ["un", "##aff", "##able"]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: A single token or whitespace separated tokens. This should have
|
text: A single token or whitespace separated tokens. This should have
|
||||||
already been passed through `BasicTokenizer.
|
already been passed through `BasicTokenizer.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of wordpiece tokens.
|
A list of wordpiece tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text = convert_to_unicode(text)
|
text = convert_to_unicode(text)
|
||||||
|
|
||||||
output_tokens = []
|
output_tokens = []
|
||||||
for token in whitespace_tokenize(text):
|
for token in whitespace_tokenize(text):
|
||||||
chars = list(token)
|
chars = list(token)
|
||||||
if len(chars) > self.max_input_chars_per_word:
|
if len(chars) > self.max_input_chars_per_word:
|
||||||
output_tokens.append(self.unk_token)
|
output_tokens.append(self.unk_token)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
is_bad = False
|
is_bad = False
|
||||||
start = 0
|
start = 0
|
||||||
sub_tokens = []
|
sub_tokens = []
|
||||||
while start < len(chars):
|
while start < len(chars):
|
||||||
end = len(chars)
|
end = len(chars)
|
||||||
cur_substr = None
|
cur_substr = None
|
||||||
while start < end:
|
while start < end:
|
||||||
substr = "".join(chars[start:end])
|
substr = "".join(chars[start:end])
|
||||||
if start > 0:
|
if start > 0:
|
||||||
substr = "##" + substr
|
substr = "##" + substr
|
||||||
if substr in self.vocab:
|
if substr in self.vocab:
|
||||||
cur_substr = substr
|
cur_substr = substr
|
||||||
break
|
break
|
||||||
end -= 1
|
end -= 1
|
||||||
if cur_substr is None:
|
if cur_substr is None:
|
||||||
is_bad = True
|
is_bad = True
|
||||||
break
|
break
|
||||||
sub_tokens.append(cur_substr)
|
sub_tokens.append(cur_substr)
|
||||||
start = end
|
start = end
|
||||||
|
|
||||||
if is_bad:
|
if is_bad:
|
||||||
output_tokens.append(self.unk_token)
|
output_tokens.append(self.unk_token)
|
||||||
else:
|
else:
|
||||||
output_tokens.extend(sub_tokens)
|
output_tokens.extend(sub_tokens)
|
||||||
return output_tokens
|
return output_tokens
|
||||||
|
|
||||||
|
|
||||||
def _is_whitespace(char):
|
def _is_whitespace(char):
|
||||||
"""Checks whether `chars` is a whitespace character."""
|
"""Checks whether `chars` is a whitespace character."""
|
||||||
# \t, \n, and \r are technically contorl characters but we treat them
|
# \t, \n, and \r are technically contorl characters but we treat them
|
||||||
# as whitespace since they are generally considered as such.
|
# as whitespace since they are generally considered as such.
|
||||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||||
return True
|
return True
|
||||||
cat = unicodedata.category(char)
|
cat = unicodedata.category(char)
|
||||||
if cat == "Zs":
|
if cat == "Zs":
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _is_control(char):
|
def _is_control(char):
|
||||||
"""Checks whether `chars` is a control character."""
|
"""Checks whether `chars` is a control character."""
|
||||||
# These are technically control characters but we count them as whitespace
|
# These are technically control characters but we count them as whitespace
|
||||||
# characters.
|
# characters.
|
||||||
if char == "\t" or char == "\n" or char == "\r":
|
if char == "\t" or char == "\n" or char == "\r":
|
||||||
|
return False
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat.startswith("C"):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
cat = unicodedata.category(char)
|
|
||||||
if cat.startswith("C"):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_punctuation(char):
|
def _is_punctuation(char):
|
||||||
"""Checks whether `chars` is a punctuation character."""
|
"""Checks whether `chars` is a punctuation character."""
|
||||||
cp = ord(char)
|
cp = ord(char)
|
||||||
# We treat all non-letter/number ASCII as punctuation.
|
# We treat all non-letter/number ASCII as punctuation.
|
||||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||||
# Punctuation class but we treat them as punctuation anyways, for
|
# Punctuation class but we treat them as punctuation anyways, for
|
||||||
# consistency.
|
# consistency.
|
||||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||||
return True
|
return True
|
||||||
cat = unicodedata.category(char)
|
cat = unicodedata.category(char)
|
||||||
if cat.startswith("P"):
|
if cat.startswith("P"):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -25,101 +25,101 @@ import tensorflow as tf
|
|||||||
|
|
||||||
class TokenizationTest(tf.test.TestCase):
|
class TokenizationTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
vocab_tokens = [
|
vocab_tokens = [
|
||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing", ","
|
"##ing", ","
|
||||||
]
|
]
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
|
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
vocab_file = vocab_writer.name
|
vocab_file = vocab_writer.name
|
||||||
|
|
||||||
tokenizer = tokenization.FullTokenizer(vocab_file)
|
tokenizer = tokenization.FullTokenizer(vocab_file)
|
||||||
os.unlink(vocab_file)
|
os.unlink(vocab_file)
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||||
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
def test_basic_tokenizer_lower(self):
|
def test_basic_tokenizer_lower(self):
|
||||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
||||||
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
["hello", "!", "how", "are", "you", "?"])
|
["hello", "!", "how", "are", "you", "?"])
|
||||||
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||||
|
|
||||||
def test_basic_tokenizer_no_lower(self):
|
def test_basic_tokenizer_no_lower(self):
|
||||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
||||||
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||||
|
|
||||||
def test_wordpiece_tokenizer(self):
|
def test_wordpiece_tokenizer(self):
|
||||||
vocab_tokens = [
|
vocab_tokens = [
|
||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing"
|
"##ing"
|
||||||
]
|
]
|
||||||
|
|
||||||
vocab = {}
|
vocab = {}
|
||||||
for (i, token) in enumerate(vocab_tokens):
|
for (i, token) in enumerate(vocab_tokens):
|
||||||
vocab[token] = i
|
vocab[token] = i
|
||||||
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
||||||
|
|
||||||
self.assertAllEqual(tokenizer.tokenize(""), [])
|
self.assertAllEqual(tokenizer.tokenize(""), [])
|
||||||
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
tokenizer.tokenize("unwanted running"),
|
tokenizer.tokenize("unwanted running"),
|
||||||
["un", "##want", "##ed", "runn", "##ing"])
|
["un", "##want", "##ed", "runn", "##ing"])
|
||||||
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||||
|
|
||||||
def test_convert_tokens_to_ids(self):
|
def test_convert_tokens_to_ids(self):
|
||||||
vocab_tokens = [
|
vocab_tokens = [
|
||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing"
|
"##ing"
|
||||||
]
|
]
|
||||||
|
|
||||||
vocab = {}
|
vocab = {}
|
||||||
for (i, token) in enumerate(vocab_tokens):
|
for (i, token) in enumerate(vocab_tokens):
|
||||||
vocab[token] = i
|
vocab[token] = i
|
||||||
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
tokenization.convert_tokens_to_ids(
|
tokenization.convert_tokens_to_ids(
|
||||||
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
||||||
|
|
||||||
def test_is_whitespace(self):
|
def test_is_whitespace(self):
|
||||||
self.assertTrue(tokenization._is_whitespace(u" "))
|
self.assertTrue(tokenization._is_whitespace(u" "))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
||||||
|
|
||||||
self.assertFalse(tokenization._is_whitespace(u"A"))
|
self.assertFalse(tokenization._is_whitespace(u"A"))
|
||||||
self.assertFalse(tokenization._is_whitespace(u"-"))
|
self.assertFalse(tokenization._is_whitespace(u"-"))
|
||||||
|
|
||||||
def test_is_control(self):
|
def test_is_control(self):
|
||||||
self.assertTrue(tokenization._is_control(u"\u0005"))
|
self.assertTrue(tokenization._is_control(u"\u0005"))
|
||||||
|
|
||||||
self.assertFalse(tokenization._is_control(u"A"))
|
self.assertFalse(tokenization._is_control(u"A"))
|
||||||
self.assertFalse(tokenization._is_control(u" "))
|
self.assertFalse(tokenization._is_control(u" "))
|
||||||
self.assertFalse(tokenization._is_control(u"\t"))
|
self.assertFalse(tokenization._is_control(u"\t"))
|
||||||
self.assertFalse(tokenization._is_control(u"\r"))
|
self.assertFalse(tokenization._is_control(u"\r"))
|
||||||
|
|
||||||
def test_is_punctuation(self):
|
def test_is_punctuation(self):
|
||||||
self.assertTrue(tokenization._is_punctuation(u"-"))
|
self.assertTrue(tokenization._is_punctuation(u"-"))
|
||||||
self.assertTrue(tokenization._is_punctuation(u"$"))
|
self.assertTrue(tokenization._is_punctuation(u"$"))
|
||||||
self.assertTrue(tokenization._is_punctuation(u"`"))
|
self.assertTrue(tokenization._is_punctuation(u"`"))
|
||||||
self.assertTrue(tokenization._is_punctuation(u"."))
|
self.assertTrue(tokenization._is_punctuation(u"."))
|
||||||
|
|
||||||
self.assertFalse(tokenization._is_punctuation(u"A"))
|
self.assertFalse(tokenization._is_punctuation(u"A"))
|
||||||
self.assertFalse(tokenization._is_punctuation(u" "))
|
self.assertFalse(tokenization._is_punctuation(u" "))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user