update tokenizer - update squad example for xlnet
This commit is contained in:
@@ -242,7 +242,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
if os.path.exists(cached_features_file):
|
||||
@@ -282,8 +282,10 @@ def main():
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
@@ -400,15 +402,11 @@ def main():
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = ""
|
||||
for key in MODEL_CLASSES:
|
||||
if key in args.model_name.lower():
|
||||
args.model_type = key # take the first match in model types
|
||||
break
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
@@ -213,7 +213,6 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
outputs = model(**inputs)
|
||||
batch_start_logits, batch_end_logits = outputs[:2]
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
@@ -242,7 +241,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
args.start_n_top, args.end_n_top, args.version_2_with_negative)
|
||||
model.config.start_n_top, model.config.end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
@@ -262,7 +262,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length)))
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
@@ -312,8 +312,10 @@ def main():
|
||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
|
||||
@@ -438,15 +440,11 @@ def main():
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = ""
|
||||
for key in MODEL_CLASSES:
|
||||
if key in args.model_name.lower():
|
||||
args.model_type = key # take the first match in model types
|
||||
break
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
@@ -60,8 +60,9 @@ class ExamplesTests(unittest.TestCase):
|
||||
"--warmup_steps=2",
|
||||
"--overwrite_output_dir",
|
||||
"--seed=42"]
|
||||
model_name = "--model_name=bert-base-uncased"
|
||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
||||
model_type, model_name = ("--model_type=bert",
|
||||
"--model_name_or_path=bert-base-uncased")
|
||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||
result = run_glue.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
||||
@@ -85,8 +86,9 @@ class ExamplesTests(unittest.TestCase):
|
||||
"--per_gpu_eval_batch_size=1",
|
||||
"--overwrite_output_dir",
|
||||
"--seed=42"]
|
||||
model_name = "--model_name=bert-base-uncased"
|
||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
||||
model_type, model_name = ("--model_type=bert",
|
||||
"--model_name_or_path=bert-base-uncased")
|
||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||
result = run_squad.main()
|
||||
self.assertGreaterEqual(result['f1'], 30)
|
||||
self.assertGreaterEqual(result['exact'], 30)
|
||||
|
||||
@@ -87,6 +87,7 @@ class InputFeatures(object):
|
||||
segment_ids,
|
||||
cls_index,
|
||||
p_mask,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
@@ -101,6 +102,7 @@ class InputFeatures(object):
|
||||
self.segment_ids = segment_ids
|
||||
self.cls_index = cls_index
|
||||
self.p_mask = p_mask
|
||||
self.paragraph_len = paragraph_len
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
@@ -292,6 +294,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
p_mask.append(0)
|
||||
paragraph_len = doc_span.length
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
@@ -385,6 +388,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
segment_ids=segment_ids,
|
||||
cls_index=cls_index,
|
||||
p_mask=p_mask,
|
||||
paragraph_len=paragraph_len,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=span_is_impossible))
|
||||
@@ -673,8 +677,9 @@ RawResultExtended = collections.namedtuple("RawResultExtended",
|
||||
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file, orig_data,
|
||||
start_n_top, end_n_top, version_2_with_negative):
|
||||
output_null_log_odds_file, orig_data_file,
|
||||
start_n_top, end_n_top, version_2_with_negative,
|
||||
tokenizer, verbose_logging):
|
||||
""" XLNet write prediction logic (more complex than Bert's).
|
||||
Write final predictions to the json file and log-odds of null if needed.
|
||||
|
||||
@@ -764,13 +769,30 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
|
||||
tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||
tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||
start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||
end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||
# XLNet un-tokenizer
|
||||
# Let's keep it simple for now and see if we need all this later.
|
||||
#
|
||||
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||
# paragraph_text = example.paragraph_text
|
||||
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||
|
||||
paragraph_text = example.paragraph_text
|
||||
final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||
# Previously used Bert untokenizer
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||
|
||||
# Clean whitespace
|
||||
tok_text = tok_text.strip()
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
||||
verbose_logging)
|
||||
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
@@ -829,6 +851,9 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
with open(output_null_log_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
with open(orig_data_file, "r", encoding='utf-8') as reader:
|
||||
orig_data = json.load(reader)["data"]
|
||||
|
||||
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
|
||||
@@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module):
|
||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
if start_positions is not None:
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
||||
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
||||
@@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module):
|
||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||
for each sample
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
hsz = hidden_states.shape[-1]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
if start_positions is not None:
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
**start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
**end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(SQuADHead, self).__init__()
|
||||
@@ -667,8 +676,8 @@ class SQuADHead(nn.Module):
|
||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||
|
||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
||||
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
@@ -1167,12 +1167,23 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
1.0 means token should be masked. 0.0 mean token is not masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
**start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
**end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
@@ -1243,12 +1254,10 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
loss_fct_cls = nn.BCEWithLogitsLoss()
|
||||
cls_loss = loss_fct_cls(cls_logits, is_impossible)
|
||||
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
|
||||
# comparable to start_loss and end_loss
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||
total_loss += cls_loss * 0.5
|
||||
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
|
||||
else:
|
||||
outputs = (total_loss, start_logits, end_logits) + outputs
|
||||
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
else:
|
||||
# during inference, compute the end logits based on beam search
|
||||
@@ -1256,8 +1265,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||
|
||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
||||
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||
@@ -1269,11 +1278,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
|
||||
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
|
||||
|
||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states
|
||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample
|
||||
|
||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
||||
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
|
||||
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||
# or (if labels are provided) (total_loss,)
|
||||
return outputs
|
||||
|
||||
@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
output_text = u"unwanted, running"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower<unk>newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
||||
text = "lower"
|
||||
|
||||
@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
||||
|
||||
|
||||
@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
|
||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||
|
||||
|
||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"He is very happy, UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokens = tokenizer.tokenize(input_text)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
ids_2 = tokenizer.encode(text)
|
||||
ids_2 = tokenizer.encode(input_text)
|
||||
tester.assertListEqual(ids, ids_2)
|
||||
|
||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||
text_2 = tokenizer.decode(ids)
|
||||
|
||||
tester.assertEqual(text_2, output_text)
|
||||
|
||||
tester.assertNotEqual(len(tokens_2), 0)
|
||||
tester.assertIsInstance(text_2, (str, unicode))
|
||||
|
||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
|
||||
@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
input_text = u"<unk> UNwanted , running"
|
||||
output_text = u"<unk> unwanted, running"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
||||
|
||||
|
||||
@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
||||
input_text = u"This is a test"
|
||||
output_text = u"This is a test"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
|
||||
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
|
||||
@@ -161,10 +161,9 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.ids_to_tokens.get(index, self.unk_token)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(tokens_ids)
|
||||
out_string = ''.join(tokens).replace(' ##', '').strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
|
||||
@@ -185,9 +185,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
text = ''.join(tokens_ids)
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
text = ''.join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
|
||||
@@ -174,9 +174,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
"""Converts an id in a token (BPE) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
|
||||
@@ -229,9 +229,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
else:
|
||||
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ' '.join(tokens_ids).strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ' '.join(tokens).strip()
|
||||
return out_string
|
||||
|
||||
def convert_to_tensor(self, symbols):
|
||||
|
||||
@@ -361,52 +361,26 @@ class PreTrainedTokenizer(object):
|
||||
(resp.) a sequence of ids, using the vocabulary.
|
||||
"""
|
||||
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
||||
return self.convert_token_to_id_with_added_voc(tokens)
|
||||
return self._convert_token_to_id_with_added_voc(tokens)
|
||||
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self.convert_token_to_id_with_added_voc(token))
|
||||
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
|
||||
def convert_token_to_id_with_added_voc(self, token):
|
||||
def _convert_token_to_id_with_added_voc(self, token):
|
||||
if token in self.added_tokens_encoder:
|
||||
return self.added_tokens_encoder[token]
|
||||
return self._convert_token_to_id(token)
|
||||
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
||||
|
||||
Args:
|
||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||
"""
|
||||
if isinstance(ids, int):
|
||||
return self.convert_id_to_token(ids)
|
||||
tokens = []
|
||||
for index in ids:
|
||||
if index in self.all_special_ids and skip_special_tokens:
|
||||
continue
|
||||
if index in self.added_tokens_decoder:
|
||||
tokens.append(self.added_tokens_decoder[index])
|
||||
else:
|
||||
tokens.append(self._convert_id_to_token(index))
|
||||
return tokens
|
||||
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def encode(self, text):
|
||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
same as self.convert_tokens_to_ids(self.tokenize(text)).
|
||||
@@ -414,22 +388,48 @@ class PreTrainedTokenizer(object):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
||||
|
||||
Args:
|
||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||
"""
|
||||
if isinstance(ids, int):
|
||||
if ids in self.added_tokens_decoder:
|
||||
return self.added_tokens_decoder[ids]
|
||||
else:
|
||||
return self._convert_id_to_token(ids)
|
||||
tokens = []
|
||||
for index in ids:
|
||||
if index in self.all_special_ids and skip_special_tokens:
|
||||
continue
|
||||
if index in self.added_tokens_decoder:
|
||||
tokens.append(self.added_tokens_decoder[index])
|
||||
else:
|
||||
tokens.append(self._convert_id_to_token(index))
|
||||
return tokens
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string.
|
||||
The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
|
||||
but we often want to remove sub-word tokenization artifacts at the same time.
|
||||
"""
|
||||
return ' '.join(self.convert_ids_to_tokens(tokens))
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||
with options to remove special tokens and clean up tokenization spaces.
|
||||
"""
|
||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
text = self._convert_ids_to_string(filtered_tokens)
|
||||
text = self.convert_tokens_to_string(filtered_tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = clean_up_tokenization(text)
|
||||
return text
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
|
||||
roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
|
||||
"""
|
||||
return ' '.join(self.convert_ids_to_tokens(tokens_ids))
|
||||
|
||||
@property
|
||||
def special_tokens_map(self):
|
||||
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
|
||||
|
||||
@@ -202,9 +202,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
|
||||
@@ -170,9 +170,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
token = token.decode('utf-8')
|
||||
return token
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
@@ -184,6 +184,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
return
|
||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
Reference in New Issue
Block a user