Updated tests

This commit is contained in:
LysandreJik
2019-09-24 07:10:09 -04:00
parent 9d44236f70
commit 0ea82b246f
2 changed files with 9 additions and 4 deletions

View File

@@ -744,7 +744,7 @@ class PreTrainedTokenizer(object):
def get_input_ids(text):
if isinstance(text, six.string_types):
input_ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
input_ids = self.convert_tokens_to_ids(text)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
input_ids = text