From c832f43a4dc01fbf55fa63eda1554912f01ac57a Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 24 Sep 2019 07:21:38 -0400 Subject: [PATCH] `output_token_type` -> `token_type_ids` --- examples/utils_glue.py | 2 +- pytorch_transformers/tests/tokenization_tests_commons.py | 2 +- pytorch_transformers/tokenization_utils.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/utils_glue.py b/examples/utils_glue.py index 225b496029..2557540cc6 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -413,7 +413,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, max_length=max_seq_length, truncate_first_sequence=True # We're truncating the first sequence as a priority ) - input_ids, segment_ids = inputs["input_ids"], inputs["output_token_type"] + input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"] # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 323e558310..4ad92c8192 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -197,7 +197,7 @@ class CommonTestCases: seq_0 = "Test this method." seq_1 = "With these inputs." information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_token_type=True) - sequences, mask = information["input_ids"], information["output_token_type"] + sequences, mask = information["input_ids"], information["token_type_ids"] assert len(sequences) == len(mask) def test_number_of_added_tokens(self): diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 478ba6da87..02b4bef699 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -765,7 +765,7 @@ class PreTrainedTokenizer(object): information["input_ids"] = sequence_tokens if output_token_type: - information["output_token_type"] = [0] * len(information["input_ids"]) + information["token_type_ids"] = [0] * len(information["input_ids"]) else: first_sentence_tokens = get_input_ids(text) second_sentence_tokens = get_input_ids(text_pair) @@ -780,7 +780,7 @@ class PreTrainedTokenizer(object): ) if output_token_type: - information["output_token_type"] = self.create_mask_from_sequences(text, text_pair) + information["token_type_ids"] = self.create_mask_from_sequences(text, text_pair) else: logger.warning("No special tokens were added. The two sequences have been concatenated.") sequence = first_sentence_tokens + second_sentence_tokens @@ -789,7 +789,7 @@ class PreTrainedTokenizer(object): information["overflowing_tokens"] = sequence[max_length - stride:] sequence = sequence[:max_length] if output_token_type: - information["output_token_type"] = [0] * len(sequence) + information["token_type_ids"] = [0] * len(sequence) information["input_ids"] = sequence