Updated tests
This commit is contained in:
@@ -264,9 +264,14 @@ class CommonTestCases:
|
|||||||
assert len(truncated_sequence) == len(sequence) - 2
|
assert len(truncated_sequence) == len(sequence) - 2
|
||||||
assert truncated_sequence == truncated_second_sequence
|
assert truncated_sequence == truncated_second_sequence
|
||||||
|
|
||||||
def test_tokens_sent_to_encode(self):
|
def test_encode_input_type(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
sequence = "Let's encode this sequence"
|
sequence = "Let's encode this sequence"
|
||||||
tokens = tokenizer.encode(sequence)
|
|
||||||
tokenizer.encode(tokens, add_special_tokens=True)
|
tokens = tokenizer.tokenize(sequence)
|
||||||
|
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
|
||||||
|
|
||||||
|
assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
|
||||||
|
assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input
|
||||||
|
|||||||
@@ -744,7 +744,7 @@ class PreTrainedTokenizer(object):
|
|||||||
def get_input_ids(text):
|
def get_input_ids(text):
|
||||||
if isinstance(text, six.string_types):
|
if isinstance(text, six.string_types):
|
||||||
input_ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
input_ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
|
||||||
input_ids = self.convert_tokens_to_ids(text)
|
input_ids = self.convert_tokens_to_ids(text)
|
||||||
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
||||||
input_ids = text
|
input_ids = text
|
||||||
|
|||||||
Reference in New Issue
Block a user