Update: add type hints to check_tokenizers.py (#40094)

* Update check_tokenizers.py

chore(typing): add type hints to check_tokenizers script

- Annotate params/returns for helper functions
- Keep tokenizer instances as `Any` to avoid runtime coupling
- Make `check_LTR_mark` return `bool` explicitly: add a final `return False` in place of the implicit `None` (truthiness unchanged for callers)

* Update check_tokenizers.py

chore(typing): replace Any with PreTrainedTokenizerBase in check_tokenizers.py

- Use transformers.tokenization_utils_base.PreTrainedTokenizerBase for `slow` and `fast` params
- Covers both PreTrainedTokenizer and PreTrainedTokenizerFast
- Exposes required methods (encode, decode, encode_plus, tokenize)
- Removes generic Any typing while staying implementation-agnostic
This commit is contained in:
Ajeet Verma
2025-08-15 19:41:28 +07:00
committed by GitHub
parent 28a03fb78a
commit de437d0d7a

View File

@@ -5,6 +5,7 @@ import datasets
import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
logging.set_verbosity_info()
@@ -21,7 +22,7 @@ imperfect = 0
wrong = 0
def check_diff(spm_diff, tok_diff, slow, fast):
def check_diff(spm_diff: list[int], tok_diff: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> bool:
if spm_diff == list(reversed(tok_diff)):
# AAA -> AA+A vs A+AA case.
return True
@@ -42,7 +43,7 @@ def check_diff(spm_diff, tok_diff, slow, fast):
return False
def check_LTR_mark(line, idx, fast):
def check_LTR_mark(line: str, idx: int, fast: PreTrainedTokenizerBase) -> bool:
enc = fast.encode_plus(line)[0]
offsets = enc.offsets
curr, prev = offsets[idx], offsets[idx - 1]
@@ -50,9 +51,10 @@ def check_LTR_mark(line, idx, fast):
return True
if prev is not None and line[prev[0] : prev[1]] == "\u200f":
return True
return False
def check_details(line, spm_ids, tok_ids, slow, fast):
def check_details(line: str, spm_ids: list[int], tok_ids: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> bool:
# Encoding can be the same with same result AAA -> A + AA vs AA + A
# We can check that we use at least exactly the same number of tokens.
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
@@ -111,7 +113,7 @@ def check_details(line, spm_ids, tok_ids, slow, fast):
return False
def test_string(slow, fast, text):
def test_string(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase, text: str) -> None:
global perfect
global imperfect
global wrong
@@ -143,7 +145,7 @@ def test_string(slow, fast, text):
), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"
def test_tokenizer(slow, fast):
def test_tokenizer(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> None:
global batch_total
for i in range(len(dataset)):
# premise, all languages