From 15d8b1266c3a399e16a9ffe2f8e0420e3c9682a4 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Jul 2019 17:30:42 +0200 Subject: [PATCH] update tokenizer - update squad example for xlnet --- examples/run_glue.py | 20 +++--- examples/run_squad.py | 24 +++---- examples/test_examples.py | 10 +-- examples/utils_squad.py | 41 ++++++++--- pytorch_transformers/modeling_utils.py | 29 +++++--- pytorch_transformers/modeling_xlnet.py | 43 ++++++----- .../tests/tokenization_bert_test.py | 5 +- .../tests/tokenization_gpt2_test.py | 5 +- .../tests/tokenization_openai_test.py | 5 +- .../tests/tokenization_tests_commons.py | 13 ++-- .../tests/tokenization_transfo_xl_test.py | 5 +- .../tests/tokenization_xlm_test.py | 5 +- .../tests/tokenization_xlnet_test.py | 5 +- pytorch_transformers/tokenization_bert.py | 7 +- pytorch_transformers/tokenization_gpt2.py | 6 +- pytorch_transformers/tokenization_openai.py | 6 +- .../tokenization_transfo_xl.py | 6 +- pytorch_transformers/tokenization_utils.py | 72 +++++++++---------- pytorch_transformers/tokenization_xlm.py | 6 +- pytorch_transformers/tokenization_xlnet.py | 9 +-- 20 files changed, 191 insertions(+), 131 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index 979c644471..f017db2f6f 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -242,7 +242,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): # Load data features from cache or dataset file cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format( 'dev' if evaluate else 'train', - list(filter(None, args.model_name.split('/'))).pop(), + list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task))) if os.path.exists(cached_features_file): @@ -282,8 +282,10 @@ def main(): ## Required parameters parser.add_argument("--data_dir", default=None, type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.") - parser.add_argument("--model_name", default=None, type=str, required=True, - help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--model_type", default=None, type=str, required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) parser.add_argument("--task_name", default=None, type=str, required=True, help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) parser.add_argument("--output_dir", default=None, type=str, required=True, @@ -400,15 +402,11 @@ def main(): if args.local_rank not in [-1, 0]: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab - args.model_type = "" - for key in MODEL_CLASSES: - if key in args.model_name.lower(): - args.model_type = key # take the first match in model types - break + args.model_type = args.model_type.lower() config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name) - tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case) - model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config) + config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name) + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) + model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) if args.local_rank == 0: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab diff --git a/examples/run_squad.py b/examples/run_squad.py index 2025217454..e920ebe378 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -213,7 +213,6 @@ def evaluate(args, model, tokenizer, prefix=""): inputs.update({'cls_index': batch[4], 'p_mask': batch[5]}) outputs = model(**inputs) - batch_start_logits, batch_end_logits = outputs[:2] for i, example_index in enumerate(example_indices): eval_feature = features[example_index.item()] @@ -242,7 +241,8 @@ def evaluate(args, model, tokenizer, prefix=""): write_predictions_extended(examples, features, all_results, args.n_best_size, args.max_answer_length, output_prediction_file, output_nbest_file, output_null_log_odds_file, args.predict_file, - args.start_n_top, args.end_n_top, args.version_2_with_negative) + model.config.start_n_top, model.config.end_n_top, + args.version_2_with_negative, tokenizer, args.verbose_logging) else: write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length, args.do_lower_case, output_prediction_file, @@ -262,7 +262,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal input_file = args.predict_file if evaluate else args.train_file cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format( 'dev' if evaluate else 'train', - list(filter(None, args.model_name.split('/'))).pop(), + list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length))) if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: logger.info("Loading features from cached file %s", cached_features_file) @@ -312,8 +312,10 @@ def main(): help="SQuAD json for training. E.g., train-v1.1.json") parser.add_argument("--predict_file", default=None, type=str, required=True, help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") - parser.add_argument("--model_name", default=None, type=str, required=True, - help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--model_type", default=None, type=str, required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model checkpoints and predictions will be written.") @@ -438,15 +440,11 @@ def main(): if args.local_rank not in [-1, 0]: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab - args.model_type = "" - for key in MODEL_CLASSES: - if key in args.model_name.lower(): - args.model_type = key # take the first match in model types - break + args.model_type = args.model_type.lower() config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name) - tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case) - model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config) + config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) + model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) if args.local_rank == 0: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab diff --git a/examples/test_examples.py b/examples/test_examples.py index a07c0ea31b..00370e9361 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -60,8 +60,9 @@ class ExamplesTests(unittest.TestCase): "--warmup_steps=2", "--overwrite_output_dir", "--seed=42"] - model_name = "--model_name=bert-base-uncased" - with patch.object(sys, 'argv', testargs + [model_name]): + model_type, model_name = ("--model_type=bert", + "--model_name_or_path=bert-base-uncased") + with patch.object(sys, 'argv', testargs + [model_type, model_name]): result = run_glue.main() for value in result.values(): self.assertGreaterEqual(value, 0.75) @@ -85,8 +86,9 @@ class ExamplesTests(unittest.TestCase): "--per_gpu_eval_batch_size=1", "--overwrite_output_dir", "--seed=42"] - model_name = "--model_name=bert-base-uncased" - with patch.object(sys, 'argv', testargs + [model_name]): + model_type, model_name = ("--model_type=bert", + "--model_name_or_path=bert-base-uncased") + with patch.object(sys, 'argv', testargs + [model_type, model_name]): result = run_squad.main() self.assertGreaterEqual(result['f1'], 30) self.assertGreaterEqual(result['exact'], 30) diff --git a/examples/utils_squad.py b/examples/utils_squad.py index d898a0a17e..34a0c9cc02 100644 --- a/examples/utils_squad.py +++ b/examples/utils_squad.py @@ -87,6 +87,7 @@ class InputFeatures(object): segment_ids, cls_index, p_mask, + paragraph_len, start_position=None, end_position=None, is_impossible=None): @@ -101,6 +102,7 @@ class InputFeatures(object): self.segment_ids = segment_ids self.cls_index = cls_index self.p_mask = p_mask + self.paragraph_len = paragraph_len self.start_position = start_position self.end_position = end_position self.is_impossible = is_impossible @@ -292,6 +294,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, tokens.append(all_doc_tokens[split_token_index]) segment_ids.append(sequence_b_segment_id) p_mask.append(0) + paragraph_len = doc_span.length # SEP token tokens.append(sep_token) @@ -385,6 +388,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, segment_ids=segment_ids, cls_index=cls_index, p_mask=p_mask, + paragraph_len=paragraph_len, start_position=start_position, end_position=end_position, is_impossible=span_is_impossible)) @@ -673,8 +677,9 @@ RawResultExtended = collections.namedtuple("RawResultExtended", def write_predictions_extended(all_examples, all_features, all_results, n_best_size, max_answer_length, output_prediction_file, output_nbest_file, - output_null_log_odds_file, orig_data, - start_n_top, end_n_top, version_2_with_negative): + output_null_log_odds_file, orig_data_file, + start_n_top, end_n_top, version_2_with_negative, + tokenizer, verbose_logging): """ XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of null if needed. @@ -764,13 +769,30 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s break feature = features[pred.feature_index] - tok_start_to_orig_index = feature.tok_start_to_orig_index - tok_end_to_orig_index = feature.tok_end_to_orig_index - start_orig_pos = tok_start_to_orig_index[pred.start_index] - end_orig_pos = tok_end_to_orig_index[pred.end_index] + # XLNet un-tokenizer + # Let's keep it simple for now and see if we need all this later. + # + # tok_start_to_orig_index = feature.tok_start_to_orig_index + # tok_end_to_orig_index = feature.tok_end_to_orig_index + # start_orig_pos = tok_start_to_orig_index[pred.start_index] + # end_orig_pos = tok_end_to_orig_index[pred.end_index] + # paragraph_text = example.paragraph_text + # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip() - paragraph_text = example.paragraph_text - final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip() + # Previously used Bert untokenizer + tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] + tok_text = tokenizer.convert_tokens_to_string(tok_tokens) + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, + verbose_logging) if final_text in seen_predictions: continue @@ -829,6 +851,9 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s with open(output_null_log_odds_file, "w") as writer: writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + with open(orig_data_file, "r", encoding='utf-8') as reader: + orig_data = json.load(reader)["data"] + qid_to_has_ans = make_qid_to_has_ans(orig_data) has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index ebee4fac1d..2c15aa740b 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module): Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 1.0 means token should be masked. """ - slen, hsz = hidden_states.shape[-2:] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) @@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample """ - slen, hsz = hidden_states.shape[-2:] + hsz = hidden_states.shape[-1] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" if start_positions is not None: start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) @@ -614,12 +614,21 @@ class SQuADHead(nn.Module): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. - **last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` - Sequence of hidden-states at the last layer of the model. - **mems**: - list of ``torch.FloatTensor`` (one for each layer): - that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model - (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. + **start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + **start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` + Indices for the top config.start_n_top start token possibilities (beam-search). + **end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size,)`` + Log probabilities for the ``is_impossible`` label of the answers. """ def __init__(self, config): super(SQuADHead, self).__init__() @@ -667,8 +676,8 @@ class SQuADHead(nn.Module): start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) - start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) - start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index 6de4d02103..848e73cfc9 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -1167,12 +1167,23 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): 1.0 means token should be masked. 0.0 mean token is not masked. Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` - Span-start scores (before SoftMax). - **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` - Span-end scores (before SoftMax). + **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. + **start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + **start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` + Indices for the top config.start_n_top start token possibilities (beam-search). + **end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size,)`` + Log probabilities for the ``is_impossible`` label of the answers. **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model @@ -1243,12 +1254,10 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): loss_fct_cls = nn.BCEWithLogitsLoss() cls_loss = loss_fct_cls(cls_logits, is_impossible) - # note(zhiliny): by default multiply the loss by 0.5 so that the scale is - # comparable to start_loss and end_loss + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss total_loss += cls_loss * 0.5 - outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs - else: - outputs = (total_loss, start_logits, end_logits) + outputs + + outputs = (total_loss,) + outputs else: # during inference, compute the end logits based on beam search @@ -1256,8 +1265,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) - start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) - start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) @@ -1269,11 +1278,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) - start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) - cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs - # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions) - # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions) + # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits + # or (if labels are provided) (total_loss,) return outputs diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index dbbe9ac5ea..0b9cfb1b32 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase): with open(vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname) + input_text = u"UNwant\u00E9d,running" + output_text = u"unwanted, running" + + create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname) tokenizer = BertTokenizer(vocab_file) diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index 8ae8896187..8dae72ec99 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase): with open(merges_file, "w") as fp: fp.write("\n".join(merges)) - create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map) + input_text = u"lower newer" + output_text = u"lowernewer" + + create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map) tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map) text = "lower" diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index f5c99877d7..9b4841a605 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): with open(merges_file, "w") as fp: fp.write("\n".join(merges)) - create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname) + input_text = u"lower newer" + output_text = u"lower newer" + + create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname) tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 44adbc6b53..e33ba3cb06 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) -def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs): +def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) - text = u"He is very happy, UNwant\u00E9d,running" - tokens = tokenizer.tokenize(text) + tokens = tokenizer.tokenize(input_text) ids = tokenizer.convert_tokens_to_ids(tokens) - ids_2 = tokenizer.encode(text) + ids_2 = tokenizer.encode(input_text) tester.assertListEqual(ids, ids_2) tokens_2 = tokenizer.convert_ids_to_tokens(ids) text_2 = tokenizer.decode(ids) + tester.assertEqual(text_2, output_text) + tester.assertNotEqual(len(tokens_2), 0) tester.assertIsInstance(text_2, (str, unicode)) -def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs): - create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs) +def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): + create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs) create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs) create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs) create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs) diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index 135f48b0ef..aecfeaef5f 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase): with open(vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True) + input_text = u" UNwanted , running" + output_text = u" unwanted, running" + + create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True) tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True) diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index 827ec1606e..97e8fa983f 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase): with open(merges_file, "w") as fp: fp.write("\n".join(merges)) - create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname) + input_text = u"lower newer" + output_text = u"lower newer" + + create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname) tokenizer = XLMTokenizer(vocab_file, merges_file) diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index e50fe9243d..27c6b984ee 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase): with TemporaryDirectory() as tmpdirname: tokenizer.save_pretrained(tmpdirname) - create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname) + input_text = u"This is a test" + output_text = u"This is a test" + + create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname) tokens = tokenizer.tokenize(u'This is a test') self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 8b34a43e5a..f1e900caaf 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -161,10 +161,9 @@ class BertTokenizer(PreTrainedTokenizer): """Converts an index (integer) in a token (string/unicode) using the vocab.""" return self.ids_to_tokens.get(index, self.unk_token) - def _convert_ids_to_string(self, tokens_ids): - """Converts a sequence of ids in a string.""" - tokens = self.convert_ids_to_tokens(tokens_ids) - out_string = ''.join(tokens).replace(' ##', '').strip() + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = ' '.join(tokens).replace(' ##', '').strip() return out_string def save_vocabulary(self, vocab_path): diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index bd90a92251..43c57c9cd3 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -185,9 +185,9 @@ class GPT2Tokenizer(PreTrainedTokenizer): """Converts an index (integer) in a token (string/unicode) using the vocab.""" return self.decoder.get(index) - def _convert_ids_to_string(self, tokens_ids): - """Converts a sequence of ids in a string.""" - text = ''.join(tokens_ids) + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + text = ''.join(tokens) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text diff --git a/pytorch_transformers/tokenization_openai.py b/pytorch_transformers/tokenization_openai.py index 16d355c57d..0eb5281d39 100644 --- a/pytorch_transformers/tokenization_openai.py +++ b/pytorch_transformers/tokenization_openai.py @@ -174,9 +174,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): """Converts an id in a token (BPE) using the vocab.""" return self.decoder.get(index, self.unk_token) - def _convert_ids_to_string(self, tokens_ids): - """Converts a sequence of ids in a string.""" - out_string = ''.join(tokens_ids).replace('', ' ').strip() + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = ''.join(tokens).replace('', ' ').strip() return out_string def save_vocabulary(self, save_directory): diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index 98b4eb6ff5..b08e8e1cca 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -229,9 +229,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer): else: raise ValueError('Token not in vocabulary and no token in vocabulary for replacement') - def _convert_ids_to_string(self, tokens_ids): - """Converts a sequence of ids in a string.""" - out_string = ' '.join(tokens_ids).strip() + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = ' '.join(tokens).strip() return out_string def convert_to_tensor(self, symbols): diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 9840e75225..d857e6f2d4 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -361,52 +361,26 @@ class PreTrainedTokenizer(object): (resp.) a sequence of ids, using the vocabulary. """ if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): - return self.convert_token_to_id_with_added_voc(tokens) + return self._convert_token_to_id_with_added_voc(tokens) ids = [] for token in tokens: - ids.append(self.convert_token_to_id_with_added_voc(token)) + ids.append(self._convert_token_to_id_with_added_voc(token)) if len(ids) > self.max_len: logger.warning("Token indices sequence length is longer than the specified maximum sequence length " "for this model ({} > {}). Running this sequence through the model will result in " "indexing errors".format(len(ids), self.max_len)) return ids - - def convert_token_to_id_with_added_voc(self, token): + def _convert_token_to_id_with_added_voc(self, token): if token in self.added_tokens_encoder: return self.added_tokens_encoder[token] return self._convert_token_to_id(token) - def _convert_token_to_id(self, token): raise NotImplementedError - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): - """ Converts a single index or a sequence of indices (integers) in a token " - (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens. - - Args: - skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False - """ - if isinstance(ids, int): - return self.convert_id_to_token(ids) - tokens = [] - for index in ids: - if index in self.all_special_ids and skip_special_tokens: - continue - if index in self.added_tokens_decoder: - tokens.append(self.added_tokens_decoder[index]) - else: - tokens.append(self._convert_id_to_token(index)) - return tokens - - - def _convert_id_to_token(self, index): - raise NotImplementedError - - def encode(self, text): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. same as self.convert_tokens_to_ids(self.tokenize(text)). @@ -414,22 +388,48 @@ class PreTrainedTokenizer(object): return self.convert_tokens_to_ids(self.tokenize(text)) + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """ Converts a single index or a sequence of indices (integers) in a token " + (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens. + + Args: + skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False + """ + if isinstance(ids, int): + if ids in self.added_tokens_decoder: + return self.added_tokens_decoder[ids] + else: + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + if index in self.all_special_ids and skip_special_tokens: + continue + if index in self.added_tokens_decoder: + tokens.append(self.added_tokens_decoder[index]) + else: + tokens.append(self._convert_id_to_token(index)) + return tokens + + def _convert_id_to_token(self, index): + raise NotImplementedError + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. + The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) + but we often want to remove sub-word tokenization artifacts at the same time. + """ + return ' '.join(self.convert_ids_to_tokens(tokens)) + def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary with options to remove special tokens and clean up tokenization spaces. """ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) - text = self._convert_ids_to_string(filtered_tokens) + text = self.convert_tokens_to_string(filtered_tokens) if clean_up_tokenization_spaces: text = clean_up_tokenization(text) return text - def _convert_ids_to_string(self, tokens_ids): - """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary. - roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)). - """ - return ' '.join(self.convert_ids_to_tokens(tokens_ids)) - @property def special_tokens_map(self): """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 58fefa104b..42b61badcd 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -202,9 +202,9 @@ class XLMTokenizer(PreTrainedTokenizer): """Converts an index (integer) in a token (string/unicode) using the vocab.""" return self.decoder.get(index, self.unk_token) - def _convert_ids_to_string(self, tokens_ids): - """Converts a sequence of ids in a string.""" - out_string = ''.join(tokens_ids).replace('', ' ').strip() + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = ''.join(tokens).replace('', ' ').strip() return out_string def save_vocabulary(self, save_directory): diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index d7317b2afc..fa60a18d8a 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -170,9 +170,9 @@ class XLNetTokenizer(PreTrainedTokenizer): token = token.decode('utf-8') return token - def _convert_ids_to_string(self, tokens_ids): - """Converts a sequence of ids in a string.""" - out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ') + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() return out_string def save_vocabulary(self, save_directory): @@ -184,6 +184,7 @@ class XLNetTokenizer(PreTrainedTokenizer): return out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) - copyfile(self.vocab_file, out_vocab_file) + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) return (out_vocab_file,)