Update naming + remove f string in run_lm_finetuning example
This commit is contained in:
@@ -276,7 +276,7 @@ class CommonTestCases:
|
||||
assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
|
||||
assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input
|
||||
|
||||
def test_sequence_ids(self):
|
||||
def test_special_tokens_mask(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sequence_0 = "Encode this."
|
||||
@@ -286,10 +286,10 @@ class CommonTestCases:
|
||||
encoded_sequence = tokenizer.encode(sequence_0)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
sequence_ids = encoded_sequence_dict["sequence_ids"]
|
||||
assert len(sequence_ids) == len(encoded_sequence_w_special)
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
assert len(special_tokens_mask) == len(encoded_sequence_w_special)
|
||||
|
||||
filtered_sequence = [(x if sequence_ids[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
assert encoded_sequence == filtered_sequence
|
||||
|
||||
@@ -297,10 +297,10 @@ class CommonTestCases:
|
||||
encoded_sequence = tokenizer.encode(sequence_0) + tokenizer.encode(sequence_1)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
sequence_ids = encoded_sequence_dict["sequence_ids"]
|
||||
assert len(sequence_ids) == len(encoded_sequence_w_special)
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
assert len(special_tokens_mask) == len(encoded_sequence_w_special)
|
||||
|
||||
filtered_sequence = [(x if sequence_ids[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
assert encoded_sequence == filtered_sequence
|
||||
|
||||
@@ -309,10 +309,10 @@ class CommonTestCases:
|
||||
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
|
||||
sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
|
||||
assert len(sequence_ids) == len(encoded_sequence_w_special)
|
||||
assert sequence_ids_orig == sequence_ids
|
||||
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
|
||||
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, special_tokens_present=True)
|
||||
assert len(special_tokens_mask) == len(encoded_sequence_w_special)
|
||||
assert special_tokens_mask_orig == special_tokens_mask
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user