Update: add type hints to check_tokenizers.py (#40094)
* Update check_tokenizers.py

  chore(typing): add type hints to check_tokenizers script

  - Annotate params/returns for helper functions
  - Keep tokenizer instances as `Any` to avoid runtime coupling
  - Make `check_LTR_mark` return `bool` explicitly (no behavior change)

* Update check_tokenizers.py

  chore(typing): replace Any with PreTrainedTokenizerBase in check_tokenizers.py

  - Use transformers.tokenization_utils_base.PreTrainedTokenizerBase for `slow` and `fast` params
  - Covers both PreTrainedTokenizer and PreTrainedTokenizerFast
  - Exposes required methods (encode, decode, encode_plus, tokenize)
  - Removes generic Any typing while staying implementation-agnostic
This commit is contained in:
@@ -5,6 +5,7 @@ import datasets
|
||||
import transformers
|
||||
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
||||
from transformers.utils import logging
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
@@ -21,7 +22,7 @@ imperfect = 0
|
||||
wrong = 0
|
||||
|
||||
|
||||
def check_diff(spm_diff, tok_diff, slow, fast):
|
||||
def check_diff(spm_diff: list[int], tok_diff: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> bool:
|
||||
if spm_diff == list(reversed(tok_diff)):
|
||||
# AAA -> AA+A vs A+AA case.
|
||||
return True
|
||||
@@ -42,7 +43,7 @@ def check_diff(spm_diff, tok_diff, slow, fast):
|
||||
return False
|
||||
|
||||
|
||||
def check_LTR_mark(line, idx, fast):
|
||||
def check_LTR_mark(line: str, idx: int, fast: PreTrainedTokenizerBase) -> bool:
|
||||
enc = fast.encode_plus(line)[0]
|
||||
offsets = enc.offsets
|
||||
curr, prev = offsets[idx], offsets[idx - 1]
|
||||
@@ -50,9 +51,10 @@ def check_LTR_mark(line, idx, fast):
|
||||
return True
|
||||
if prev is not None and line[prev[0] : prev[1]] == "\u200f":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_details(line, spm_ids, tok_ids, slow, fast):
|
||||
def check_details(line: str, spm_ids: list[int], tok_ids: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> bool:
|
||||
# Encoding can be the same with same result AAA -> A + AA vs AA + A
|
||||
# We can check that we use at least exactly the same number of tokens.
|
||||
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
|
||||
@@ -111,7 +113,7 @@ def check_details(line, spm_ids, tok_ids, slow, fast):
|
||||
return False
|
||||
|
||||
|
||||
def test_string(slow, fast, text):
|
||||
def test_string(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase, text: str) -> None:
|
||||
global perfect
|
||||
global imperfect
|
||||
global wrong
|
||||
@@ -143,7 +145,7 @@ def test_string(slow, fast, text):
|
||||
), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"
|
||||
|
||||
|
||||
def test_tokenizer(slow, fast):
|
||||
def test_tokenizer(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> None:
|
||||
global batch_total
|
||||
for i in range(len(dataset)):
|
||||
# premise, all languages
|
||||
|
||||
Reference in New Issue
Block a user