update tokenizer - update squad example for xlnet
This commit is contained in:
@@ -242,7 +242,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||||
'dev' if evaluate else 'train',
|
'dev' if evaluate else 'train',
|
||||||
list(filter(None, args.model_name.split('/'))).pop(),
|
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task)))
|
str(task)))
|
||||||
if os.path.exists(cached_features_file):
|
if os.path.exists(cached_features_file):
|
||||||
@@ -282,8 +282,10 @@ def main():
|
|||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||||
|
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
@@ -400,15 +402,11 @@ def main():
|
|||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|
||||||
args.model_type = ""
|
args.model_type = args.model_type.lower()
|
||||||
for key in MODEL_CLASSES:
|
|
||||||
if key in args.model_name.lower():
|
|
||||||
args.model_type = key # take the first match in model types
|
|
||||||
break
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|||||||
@@ -213,7 +213,6 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
inputs.update({'cls_index': batch[4],
|
inputs.update({'cls_index': batch[4],
|
||||||
'p_mask': batch[5]})
|
'p_mask': batch[5]})
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
batch_start_logits, batch_end_logits = outputs[:2]
|
|
||||||
|
|
||||||
for i, example_index in enumerate(example_indices):
|
for i, example_index in enumerate(example_indices):
|
||||||
eval_feature = features[example_index.item()]
|
eval_feature = features[example_index.item()]
|
||||||
@@ -242,7 +241,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, output_prediction_file,
|
args.max_answer_length, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||||
args.start_n_top, args.end_n_top, args.version_2_with_negative)
|
model.config.start_n_top, model.config.end_n_top,
|
||||||
|
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||||
else:
|
else:
|
||||||
write_predictions(examples, features, all_results, args.n_best_size,
|
write_predictions(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||||
@@ -262,7 +262,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
input_file = args.predict_file if evaluate else args.train_file
|
input_file = args.predict_file if evaluate else args.train_file
|
||||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||||
'dev' if evaluate else 'train',
|
'dev' if evaluate else 'train',
|
||||||
list(filter(None, args.model_name.split('/'))).pop(),
|
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||||
str(args.max_seq_length)))
|
str(args.max_seq_length)))
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
@@ -312,8 +312,10 @@ def main():
|
|||||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||||
|
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
help="The output directory where the model checkpoints and predictions will be written.")
|
help="The output directory where the model checkpoints and predictions will be written.")
|
||||||
|
|
||||||
@@ -438,15 +440,11 @@ def main():
|
|||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|
||||||
args.model_type = ""
|
args.model_type = args.model_type.lower()
|
||||||
for key in MODEL_CLASSES:
|
|
||||||
if key in args.model_name.lower():
|
|
||||||
args.model_type = key # take the first match in model types
|
|
||||||
break
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name)
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|||||||
@@ -60,8 +60,9 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
"--warmup_steps=2",
|
"--warmup_steps=2",
|
||||||
"--overwrite_output_dir",
|
"--overwrite_output_dir",
|
||||||
"--seed=42"]
|
"--seed=42"]
|
||||||
model_name = "--model_name=bert-base-uncased"
|
model_type, model_name = ("--model_type=bert",
|
||||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
"--model_name_or_path=bert-base-uncased")
|
||||||
|
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||||
result = run_glue.main()
|
result = run_glue.main()
|
||||||
for value in result.values():
|
for value in result.values():
|
||||||
self.assertGreaterEqual(value, 0.75)
|
self.assertGreaterEqual(value, 0.75)
|
||||||
@@ -85,8 +86,9 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
"--per_gpu_eval_batch_size=1",
|
"--per_gpu_eval_batch_size=1",
|
||||||
"--overwrite_output_dir",
|
"--overwrite_output_dir",
|
||||||
"--seed=42"]
|
"--seed=42"]
|
||||||
model_name = "--model_name=bert-base-uncased"
|
model_type, model_name = ("--model_type=bert",
|
||||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
"--model_name_or_path=bert-base-uncased")
|
||||||
|
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||||
result = run_squad.main()
|
result = run_squad.main()
|
||||||
self.assertGreaterEqual(result['f1'], 30)
|
self.assertGreaterEqual(result['f1'], 30)
|
||||||
self.assertGreaterEqual(result['exact'], 30)
|
self.assertGreaterEqual(result['exact'], 30)
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class InputFeatures(object):
|
|||||||
segment_ids,
|
segment_ids,
|
||||||
cls_index,
|
cls_index,
|
||||||
p_mask,
|
p_mask,
|
||||||
|
paragraph_len,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None,
|
end_position=None,
|
||||||
is_impossible=None):
|
is_impossible=None):
|
||||||
@@ -101,6 +102,7 @@ class InputFeatures(object):
|
|||||||
self.segment_ids = segment_ids
|
self.segment_ids = segment_ids
|
||||||
self.cls_index = cls_index
|
self.cls_index = cls_index
|
||||||
self.p_mask = p_mask
|
self.p_mask = p_mask
|
||||||
|
self.paragraph_len = paragraph_len
|
||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
self.is_impossible = is_impossible
|
self.is_impossible = is_impossible
|
||||||
@@ -292,6 +294,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
tokens.append(all_doc_tokens[split_token_index])
|
tokens.append(all_doc_tokens[split_token_index])
|
||||||
segment_ids.append(sequence_b_segment_id)
|
segment_ids.append(sequence_b_segment_id)
|
||||||
p_mask.append(0)
|
p_mask.append(0)
|
||||||
|
paragraph_len = doc_span.length
|
||||||
|
|
||||||
# SEP token
|
# SEP token
|
||||||
tokens.append(sep_token)
|
tokens.append(sep_token)
|
||||||
@@ -385,6 +388,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
segment_ids=segment_ids,
|
segment_ids=segment_ids,
|
||||||
cls_index=cls_index,
|
cls_index=cls_index,
|
||||||
p_mask=p_mask,
|
p_mask=p_mask,
|
||||||
|
paragraph_len=paragraph_len,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position,
|
end_position=end_position,
|
||||||
is_impossible=span_is_impossible))
|
is_impossible=span_is_impossible))
|
||||||
@@ -673,8 +677,9 @@ RawResultExtended = collections.namedtuple("RawResultExtended",
|
|||||||
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||||
max_answer_length, output_prediction_file,
|
max_answer_length, output_prediction_file,
|
||||||
output_nbest_file,
|
output_nbest_file,
|
||||||
output_null_log_odds_file, orig_data,
|
output_null_log_odds_file, orig_data_file,
|
||||||
start_n_top, end_n_top, version_2_with_negative):
|
start_n_top, end_n_top, version_2_with_negative,
|
||||||
|
tokenizer, verbose_logging):
|
||||||
""" XLNet write prediction logic (more complex than Bert's).
|
""" XLNet write prediction logic (more complex than Bert's).
|
||||||
Write final predictions to the json file and log-odds of null if needed.
|
Write final predictions to the json file and log-odds of null if needed.
|
||||||
|
|
||||||
@@ -764,13 +769,30 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
|||||||
break
|
break
|
||||||
feature = features[pred.feature_index]
|
feature = features[pred.feature_index]
|
||||||
|
|
||||||
tok_start_to_orig_index = feature.tok_start_to_orig_index
|
# XLNet un-tokenizer
|
||||||
tok_end_to_orig_index = feature.tok_end_to_orig_index
|
# Let's keep it simple for now and see if we need all this later.
|
||||||
start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
#
|
||||||
end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||||
|
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||||
|
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||||
|
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||||
|
# paragraph_text = example.paragraph_text
|
||||||
|
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||||
|
|
||||||
paragraph_text = example.paragraph_text
|
# Previously used Bert untokenizer
|
||||||
final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||||
|
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||||
|
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||||
|
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||||
|
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||||
|
|
||||||
|
# Clean whitespace
|
||||||
|
tok_text = tok_text.strip()
|
||||||
|
tok_text = " ".join(tok_text.split())
|
||||||
|
orig_text = " ".join(orig_tokens)
|
||||||
|
|
||||||
|
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
||||||
|
verbose_logging)
|
||||||
|
|
||||||
if final_text in seen_predictions:
|
if final_text in seen_predictions:
|
||||||
continue
|
continue
|
||||||
@@ -829,6 +851,9 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
|||||||
with open(output_null_log_odds_file, "w") as writer:
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
with open(orig_data_file, "r", encoding='utf-8') as reader:
|
||||||
|
orig_data = json.load(reader)["data"]
|
||||||
|
|
||||||
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||||
|
|||||||
@@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module):
|
|||||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||||
1.0 means token should be masked.
|
1.0 means token should be masked.
|
||||||
"""
|
"""
|
||||||
slen, hsz = hidden_states.shape[-2:]
|
|
||||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||||
if start_positions is not None:
|
if start_positions is not None:
|
||||||
|
slen, hsz = hidden_states.shape[-2:]
|
||||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||||
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
||||||
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
||||||
@@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module):
|
|||||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||||
for each sample
|
for each sample
|
||||||
"""
|
"""
|
||||||
slen, hsz = hidden_states.shape[-2:]
|
hsz = hidden_states.shape[-1]
|
||||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||||
if start_positions is not None:
|
if start_positions is not None:
|
||||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||||
@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
|
|||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||||
**last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
**start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
Sequence of hidden-states at the last layer of the model.
|
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||||
**mems**:
|
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||||
list of ``torch.FloatTensor`` (one for each layer):
|
**start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||||
|
**end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||||
|
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||||
|
**end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||||
|
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||||
|
**cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||||
|
Log probabilities for the ``is_impossible`` label of the answers.
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(SQuADHead, self).__init__()
|
super(SQuADHead, self).__init__()
|
||||||
@@ -667,8 +676,8 @@ class SQuADHead(nn.Module):
|
|||||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||||
|
|
||||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||||
|
|
||||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||||
|
|||||||
@@ -1167,12 +1167,23 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
1.0 means token should be masked. 0.0 mean token is not masked.
|
1.0 means token should be masked. 0.0 mean token is not masked.
|
||||||
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
**start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
Span-start scores (before SoftMax).
|
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||||
Span-end scores (before SoftMax).
|
**start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||||
|
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||||
|
**end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||||
|
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||||
|
**end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||||
|
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||||
|
**cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||||
|
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||||
|
Log probabilities for the ``is_impossible`` label of the answers.
|
||||||
**mems**:
|
**mems**:
|
||||||
list of ``torch.FloatTensor`` (one for each layer):
|
list of ``torch.FloatTensor`` (one for each layer):
|
||||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
@@ -1243,12 +1254,10 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
loss_fct_cls = nn.BCEWithLogitsLoss()
|
loss_fct_cls = nn.BCEWithLogitsLoss()
|
||||||
cls_loss = loss_fct_cls(cls_logits, is_impossible)
|
cls_loss = loss_fct_cls(cls_logits, is_impossible)
|
||||||
|
|
||||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
|
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||||
# comparable to start_loss and end_loss
|
|
||||||
total_loss += cls_loss * 0.5
|
total_loss += cls_loss * 0.5
|
||||||
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
|
|
||||||
else:
|
outputs = (total_loss,) + outputs
|
||||||
outputs = (total_loss, start_logits, end_logits) + outputs
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# during inference, compute the end logits based on beam search
|
# during inference, compute the end logits based on beam search
|
||||||
@@ -1256,8 +1265,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||||
|
|
||||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||||
|
|
||||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||||
@@ -1269,11 +1278,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
|
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
|
||||||
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
|
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
|
||||||
|
|
||||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states
|
||||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample
|
||||||
|
|
||||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
||||||
|
|
||||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
|
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||||
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
|
# or (if labels are provided) (total_loss,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
|
input_text = u"UNwant\u00E9d,running"
|
||||||
|
output_text = u"unwanted, running"
|
||||||
|
|
||||||
|
create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
|
||||||
|
|
||||||
tokenizer = BertTokenizer(vocab_file)
|
tokenizer = BertTokenizer(vocab_file)
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
|
|||||||
with open(merges_file, "w") as fp:
|
with open(merges_file, "w") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
input_text = u"lower newer"
|
||||||
|
output_text = u"lower<unk>newer"
|
||||||
|
|
||||||
|
create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||||
|
|
||||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
||||||
text = "lower"
|
text = "lower"
|
||||||
|
|||||||
@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
|||||||
with open(merges_file, "w") as fp:
|
with open(merges_file, "w") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
|
input_text = u"lower newer"
|
||||||
|
output_text = u"lower newer"
|
||||||
|
|
||||||
|
create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
|
||||||
|
|
||||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
||||||
|
|
||||||
|
|||||||
@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
|
|||||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||||
|
|
||||||
text = u"He is very happy, UNwant\u00E9d,running"
|
tokens = tokenizer.tokenize(input_text)
|
||||||
tokens = tokenizer.tokenize(text)
|
|
||||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
ids_2 = tokenizer.encode(text)
|
ids_2 = tokenizer.encode(input_text)
|
||||||
tester.assertListEqual(ids, ids_2)
|
tester.assertListEqual(ids, ids_2)
|
||||||
|
|
||||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||||
text_2 = tokenizer.decode(ids)
|
text_2 = tokenizer.decode(ids)
|
||||||
|
|
||||||
|
tester.assertEqual(text_2, output_text)
|
||||||
|
|
||||||
tester.assertNotEqual(len(tokens_2), 0)
|
tester.assertNotEqual(len(tokens_2), 0)
|
||||||
tester.assertIsInstance(text_2, (str, unicode))
|
tester.assertIsInstance(text_2, (str, unicode))
|
||||||
|
|
||||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
|
||||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
|||||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
input_text = u"<unk> UNwanted , running"
|
||||||
|
output_text = u"<unk> unwanted, running"
|
||||||
|
|
||||||
|
create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||||
|
|
||||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
|
|||||||
with open(merges_file, "w") as fp:
|
with open(merges_file, "w") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
|
input_text = u"lower newer"
|
||||||
|
output_text = u"lower newer"
|
||||||
|
|
||||||
|
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
|
||||||
|
|
||||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
|
|||||||
with TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
tokenizer.save_pretrained(tmpdirname)
|
tokenizer.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
input_text = u"This is a test"
|
||||||
|
output_text = u"This is a test"
|
||||||
|
|
||||||
|
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u'This is a test')
|
tokens = tokenizer.tokenize(u'This is a test')
|
||||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||||
|
|||||||
@@ -161,10 +161,9 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
return self.ids_to_tokens.get(index, self.unk_token)
|
return self.ids_to_tokens.get(index, self.unk_token)
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of ids in a string."""
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
tokens = self.convert_ids_to_tokens(tokens_ids)
|
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
||||||
out_string = ''.join(tokens).replace(' ##', '').strip()
|
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, vocab_path):
|
||||||
|
|||||||
@@ -185,9 +185,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
|||||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
return self.decoder.get(index)
|
return self.decoder.get(index)
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of ids in a string."""
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
text = ''.join(tokens_ids)
|
text = ''.join(tokens)
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|||||||
@@ -174,9 +174,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
|||||||
"""Converts an id in a token (BPE) using the vocab."""
|
"""Converts an id in a token (BPE) using the vocab."""
|
||||||
return self.decoder.get(index, self.unk_token)
|
return self.decoder.get(index, self.unk_token)
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of ids in a string."""
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory):
|
def save_vocabulary(self, save_directory):
|
||||||
|
|||||||
@@ -229,9 +229,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
|
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of ids in a string."""
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
out_string = ' '.join(tokens_ids).strip()
|
out_string = ' '.join(tokens).strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def convert_to_tensor(self, symbols):
|
def convert_to_tensor(self, symbols):
|
||||||
|
|||||||
@@ -361,52 +361,26 @@ class PreTrainedTokenizer(object):
|
|||||||
(resp.) a sequence of ids, using the vocabulary.
|
(resp.) a sequence of ids, using the vocabulary.
|
||||||
"""
|
"""
|
||||||
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
||||||
return self.convert_token_to_id_with_added_voc(tokens)
|
return self._convert_token_to_id_with_added_voc(tokens)
|
||||||
|
|
||||||
ids = []
|
ids = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
ids.append(self.convert_token_to_id_with_added_voc(token))
|
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||||
if len(ids) > self.max_len:
|
if len(ids) > self.max_len:
|
||||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||||
"indexing errors".format(len(ids), self.max_len))
|
"indexing errors".format(len(ids), self.max_len))
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
def _convert_token_to_id_with_added_voc(self, token):
|
||||||
def convert_token_to_id_with_added_voc(self, token):
|
|
||||||
if token in self.added_tokens_encoder:
|
if token in self.added_tokens_encoder:
|
||||||
return self.added_tokens_encoder[token]
|
return self.added_tokens_encoder[token]
|
||||||
return self._convert_token_to_id(token)
|
return self._convert_token_to_id(token)
|
||||||
|
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
|
||||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
|
||||||
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
|
||||||
"""
|
|
||||||
if isinstance(ids, int):
|
|
||||||
return self.convert_id_to_token(ids)
|
|
||||||
tokens = []
|
|
||||||
for index in ids:
|
|
||||||
if index in self.all_special_ids and skip_special_tokens:
|
|
||||||
continue
|
|
||||||
if index in self.added_tokens_decoder:
|
|
||||||
tokens.append(self.added_tokens_decoder[index])
|
|
||||||
else:
|
|
||||||
tokens.append(self._convert_id_to_token(index))
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_id_to_token(self, index):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||||
same as self.convert_tokens_to_ids(self.tokenize(text)).
|
same as self.convert_tokens_to_ids(self.tokenize(text)).
|
||||||
@@ -414,22 +388,48 @@ class PreTrainedTokenizer(object):
|
|||||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||||
|
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||||
|
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||||
|
"""
|
||||||
|
if isinstance(ids, int):
|
||||||
|
if ids in self.added_tokens_decoder:
|
||||||
|
return self.added_tokens_decoder[ids]
|
||||||
|
else:
|
||||||
|
return self._convert_id_to_token(ids)
|
||||||
|
tokens = []
|
||||||
|
for index in ids:
|
||||||
|
if index in self.all_special_ids and skip_special_tokens:
|
||||||
|
continue
|
||||||
|
if index in self.added_tokens_decoder:
|
||||||
|
tokens.append(self.added_tokens_decoder[index])
|
||||||
|
else:
|
||||||
|
tokens.append(self._convert_id_to_token(index))
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens):
|
||||||
|
""" Converts a sequence of tokens (string) in a single string.
|
||||||
|
The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
|
||||||
|
but we often want to remove sub-word tokenization artifacts at the same time.
|
||||||
|
"""
|
||||||
|
return ' '.join(self.convert_ids_to_tokens(tokens))
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||||
with options to remove special tokens and clean up tokenization spaces.
|
with options to remove special tokens and clean up tokenization spaces.
|
||||||
"""
|
"""
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
text = self._convert_ids_to_string(filtered_tokens)
|
text = self.convert_tokens_to_string(filtered_tokens)
|
||||||
if clean_up_tokenization_spaces:
|
if clean_up_tokenization_spaces:
|
||||||
text = clean_up_tokenization(text)
|
text = clean_up_tokenization(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
|
||||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
|
|
||||||
roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
|
|
||||||
"""
|
|
||||||
return ' '.join(self.convert_ids_to_tokens(tokens_ids))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def special_tokens_map(self):
|
def special_tokens_map(self):
|
||||||
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
|
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
|
||||||
|
|||||||
@@ -202,9 +202,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
return self.decoder.get(index, self.unk_token)
|
return self.decoder.get(index, self.unk_token)
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of ids in a string."""
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory):
|
def save_vocabulary(self, save_directory):
|
||||||
|
|||||||
@@ -170,9 +170,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
token = token.decode('utf-8')
|
token = token.decode('utf-8')
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of ids in a string."""
|
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||||
out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
|
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory):
|
def save_vocabulary(self, save_directory):
|
||||||
@@ -184,6 +184,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
return
|
return
|
||||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
|
|
||||||
|
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||||
copyfile(self.vocab_file, out_vocab_file)
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
return (out_vocab_file,)
|
return (out_vocab_file,)
|
||||||
|
|||||||
Reference in New Issue
Block a user