Sequence IDS

This commit is contained in:
LysandreJik
2019-09-30 11:48:18 -04:00
parent 7c789c337d
commit 2f259b228e
6 changed files with 121 additions and 1 deletions

View File

@@ -292,3 +292,33 @@ class CommonTestCases:
assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input
def test_sequence_ids(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
sequence_1 = "This one too please."
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0)
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
sequence_ids = encoded_sequence_dict["sequence_ids"]
assert len(sequence_ids) == len(encoded_sequence_w_special)
filtered_sequence = [(x if sequence_ids[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
filtered_sequence = [x for x in filtered_sequence if x is not None]
assert encoded_sequence == filtered_sequence
# Testing inputs pairs
encoded_sequence = tokenizer.encode(sequence_0) + tokenizer.encode(sequence_1)
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
sequence_ids = encoded_sequence_dict["sequence_ids"]
assert len(sequence_ids) == len(encoded_sequence_w_special)
filtered_sequence = [(x if sequence_ids[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
filtered_sequence = [x for x in filtered_sequence if x is not None]
assert encoded_sequence == filtered_sequence