Update: add type hints to check_tokenizers.py (#40094)
* Update check_tokenizers.py chore(typing): add type hints to check_tokenizers script - Annotate params/returns for helper functions - Keep tokenizer instances as `Any` to avoid runtime coupling - Make `check_LTR_mark` return `bool` explicitly (no behavior change) * Update check_tokenizers.py chore(typing): replace Any with PreTrainedTokenizerBase in check_tokenizers.py - Use transformers.tokenization_utils_base.PreTrainedTokenizerBase for `slow` and `fast` params - Covers both PreTrainedTokenizer and PreTrainedTokenizerFast - Exposes required methods (encode, decode, encode_plus, tokenize) - Removes generic Any typing while staying implementation-agnostic
This commit is contained in:
@@ -5,6 +5,7 @@ import datasets
|
|||||||
import transformers
|
import transformers
|
||||||
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
@@ -21,7 +22,7 @@ imperfect = 0
|
|||||||
wrong = 0
|
wrong = 0
|
||||||
|
|
||||||
|
|
||||||
def check_diff(spm_diff, tok_diff, slow, fast):
|
def check_diff(spm_diff: list[int], tok_diff: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> bool:
|
||||||
if spm_diff == list(reversed(tok_diff)):
|
if spm_diff == list(reversed(tok_diff)):
|
||||||
# AAA -> AA+A vs A+AA case.
|
# AAA -> AA+A vs A+AA case.
|
||||||
return True
|
return True
|
||||||
@@ -42,7 +43,7 @@ def check_diff(spm_diff, tok_diff, slow, fast):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_LTR_mark(line, idx, fast):
|
def check_LTR_mark(line: str, idx: int, fast: PreTrainedTokenizerBase) -> bool:
|
||||||
enc = fast.encode_plus(line)[0]
|
enc = fast.encode_plus(line)[0]
|
||||||
offsets = enc.offsets
|
offsets = enc.offsets
|
||||||
curr, prev = offsets[idx], offsets[idx - 1]
|
curr, prev = offsets[idx], offsets[idx - 1]
|
||||||
@@ -50,9 +51,10 @@ def check_LTR_mark(line, idx, fast):
|
|||||||
return True
|
return True
|
||||||
if prev is not None and line[prev[0] : prev[1]] == "\u200f":
|
if prev is not None and line[prev[0] : prev[1]] == "\u200f":
|
||||||
return True
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_details(line, spm_ids, tok_ids, slow, fast):
|
def check_details(line: str, spm_ids: list[int], tok_ids: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> bool:
|
||||||
# Encoding can be the same with same result AAA -> A + AA vs AA + A
|
# Encoding can be the same with same result AAA -> A + AA vs AA + A
|
||||||
# We can check that we use at least exactly the same number of tokens.
|
# We can check that we use at least exactly the same number of tokens.
|
||||||
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
|
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
|
||||||
@@ -111,7 +113,7 @@ def check_details(line, spm_ids, tok_ids, slow, fast):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def test_string(slow, fast, text):
|
def test_string(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase, text: str) -> None:
|
||||||
global perfect
|
global perfect
|
||||||
global imperfect
|
global imperfect
|
||||||
global wrong
|
global wrong
|
||||||
@@ -143,7 +145,7 @@ def test_string(slow, fast, text):
|
|||||||
), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"
|
), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"
|
||||||
|
|
||||||
|
|
||||||
def test_tokenizer(slow, fast):
|
def test_tokenizer(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> None:
|
||||||
global batch_total
|
global batch_total
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
# premise, all languages
|
# premise, all languages
|
||||||
|
|||||||
Reference in New Issue
Block a user