Adding Llama FastTokenizer support. (#22264)
* Adding Llama FastTokenizer support. - Requires https://github.com/huggingface/tokenizers/pull/1183 version - Only support byte_fallback for llama, raise otherwise (safety net). - Lots of questions are special tokens How to test: ```python from transformers.convert_slow_tokenizer import convert_slow_tokenizer from transformers import AutoTokenizer from tokenizers import Tokenizer tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") if False: new_tokenizer = Tokenizer.from_file("tok.json") else: new_tokenizer = convert_slow_tokenizer(tokenizer) new_tokenizer.save("tok.json") strings = [ "This is a test", "生活的真谛是", "生活的真谛是[MASK]。", # XXX: This one is problematic because of special tokens # "<s> Something something", ] for string in strings: encoded = tokenizer(string)["input_ids"] encoded2 = new_tokenizer.encode(string).ids assert encoded == encoded2, f"{encoded} != {encoded2}" decoded = tokenizer.decode(encoded) decoded2 = new_tokenizer.decode(encoded2) assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}" ``` The converter + some test script. The test script. Tmp save. Adding Fast tokenizer + tests. Adding the tokenization tests. Correct combination. Small fix. Fixing tests. Fixing with latest update. Rebased. fix copies + normalized added tokens + copies. Adding doc. TMP. Doc + split files. Doc. Versions + try import. Fix Camembert + warnings -> Error. Fix by ArthurZucker. Not a decorator. * Fixing comments. * Adding more to docstring. * Doc rewriting.
This commit is contained in:
@@ -336,7 +336,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| LeViT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| LeViT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| LiLT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| LiLT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| LLaMA | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| LLaMA | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ |
|
| LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||||
| LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
|||||||
@@ -59,6 +59,14 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr
|
|||||||
- create_token_type_ids_from_sequences
|
- create_token_type_ids_from_sequences
|
||||||
- save_vocabulary
|
- save_vocabulary
|
||||||
|
|
||||||
|
## LlamaTokenizerFast
|
||||||
|
|
||||||
|
[[autodoc]] LlamaTokenizerFast
|
||||||
|
- build_inputs_with_special_tokens
|
||||||
|
- get_special_tokens_mask
|
||||||
|
- create_token_type_ids_from_sequences
|
||||||
|
- save_vocabulary
|
||||||
|
|
||||||
## LlamaModel
|
## LlamaModel
|
||||||
|
|
||||||
[[autodoc]] LlamaModel
|
[[autodoc]] LlamaModel
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -78,7 +78,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from setuptools import setup, Command
|
from setuptools import Command, setup
|
||||||
|
|
||||||
|
|
||||||
# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
|
# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
|
||||||
@@ -251,6 +251,7 @@ class DepsTableUpdateCommand(Command):
|
|||||||
with open(target, "w", encoding="utf-8", newline="\n") as f:
|
with open(target, "w", encoding="utf-8", newline="\n") as f:
|
||||||
f.write("\n".join(content))
|
f.write("\n".join(content))
|
||||||
|
|
||||||
|
|
||||||
extras = {}
|
extras = {}
|
||||||
|
|
||||||
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "rhoknp")
|
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "rhoknp")
|
||||||
|
|||||||
@@ -740,6 +740,7 @@ else:
|
|||||||
_import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
|
_import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
|
||||||
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
|
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
|
||||||
_import_structure["models.led"].append("LEDTokenizerFast")
|
_import_structure["models.led"].append("LEDTokenizerFast")
|
||||||
|
_import_structure["models.llama"].append("LlamaTokenizerFast")
|
||||||
_import_structure["models.longformer"].append("LongformerTokenizerFast")
|
_import_structure["models.longformer"].append("LongformerTokenizerFast")
|
||||||
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
|
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
|
||||||
_import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
|
_import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
|
||||||
@@ -4388,6 +4389,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.layoutlmv3 import LayoutLMv3TokenizerFast
|
from .models.layoutlmv3 import LayoutLMv3TokenizerFast
|
||||||
from .models.layoutxlm import LayoutXLMTokenizerFast
|
from .models.layoutxlm import LayoutXLMTokenizerFast
|
||||||
from .models.led import LEDTokenizerFast
|
from .models.led import LEDTokenizerFast
|
||||||
|
from .models.llama import LlamaTokenizerFast
|
||||||
from .models.longformer import LongformerTokenizerFast
|
from .models.longformer import LongformerTokenizerFast
|
||||||
from .models.lxmert import LxmertTokenizerFast
|
from .models.lxmert import LxmertTokenizerFast
|
||||||
from .models.markuplm import MarkupLMTokenizerFast
|
from .models.markuplm import MarkupLMTokenizerFast
|
||||||
|
|||||||
@@ -19,10 +19,9 @@ All the conversions are grouped here to gather SentencePiece dependencies outsid
|
|||||||
allow to make our dependency on SentencePiece optional.
|
allow to make our dependency on SentencePiece optional.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
||||||
from tokenizers.models import BPE, Unigram, WordPiece
|
from tokenizers.models import BPE, Unigram, WordPiece
|
||||||
|
|
||||||
from .utils import requires_backends
|
from .utils import requires_backends
|
||||||
@@ -450,12 +449,13 @@ class SpmConverter(Converter):
|
|||||||
self.proto = m
|
self.proto = m
|
||||||
|
|
||||||
if self.proto.trainer_spec.byte_fallback:
|
if self.proto.trainer_spec.byte_fallback:
|
||||||
warnings.warn(
|
if not getattr(self, "handle_byte_fallback", None):
|
||||||
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
raise RuntimeError(
|
||||||
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
||||||
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
|
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
||||||
"unknown tokens into a sequence of byte tokens matching the original piece of text."
|
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
|
||||||
)
|
"unknown tokens into a sequence of byte tokens matching the original piece of text."
|
||||||
|
)
|
||||||
|
|
||||||
def vocab(self, proto):
|
def vocab(self, proto):
|
||||||
return [(piece.piece, piece.score) for piece in proto.pieces]
|
return [(piece.piece, piece.score) for piece in proto.pieces]
|
||||||
@@ -1094,6 +1094,78 @@ class XGLMConverter(SpmConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaConverter(SpmConverter):
|
||||||
|
handle_byte_fallback = True
|
||||||
|
|
||||||
|
def vocab(self, proto):
|
||||||
|
vocab = [
|
||||||
|
("<unk>", 0.0),
|
||||||
|
("<s>", 0.0),
|
||||||
|
("</s>", 0.0),
|
||||||
|
]
|
||||||
|
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
def unk_id(self, proto):
|
||||||
|
unk_id = 0
|
||||||
|
return unk_id
|
||||||
|
|
||||||
|
def decoder(self, replacement, add_prefix_space):
|
||||||
|
return decoders.Sequence(
|
||||||
|
[
|
||||||
|
decoders.Replace("▁", " "),
|
||||||
|
decoders.ByteFallback(),
|
||||||
|
decoders.Fuse(),
|
||||||
|
decoders.Strip(content=" ", left=1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def tokenizer(self, proto):
|
||||||
|
model_type = proto.trainer_spec.model_type
|
||||||
|
vocab_scores = self.vocab(proto)
|
||||||
|
if model_type == 1:
|
||||||
|
raise RuntimeError("Llama is supposed to be a BPE model!")
|
||||||
|
elif model_type == 2:
|
||||||
|
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
||||||
|
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
||||||
|
tokenizer = Tokenizer(
|
||||||
|
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
|
||||||
|
)
|
||||||
|
tokenizer.add_special_tokens(
|
||||||
|
[
|
||||||
|
AddedToken("<unk>", normalized=True),
|
||||||
|
AddedToken("<s>", normalized=True),
|
||||||
|
AddedToken("</s>", normalized=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
||||||
|
)
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def normalizer(self, proto):
|
||||||
|
return normalizers.Sequence(
|
||||||
|
[
|
||||||
|
normalizers.Prepend(prepend="▁"),
|
||||||
|
normalizers.Replace(pattern=" ", content="▁"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def pre_tokenizer(self, replacement, add_prefix_space):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def post_processor(self):
|
||||||
|
return processors.TemplateProcessing(
|
||||||
|
single="<s> $A",
|
||||||
|
pair="<s> $A $B",
|
||||||
|
special_tokens=[
|
||||||
|
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MarkupLMConverter(Converter):
|
class MarkupLMConverter(Converter):
|
||||||
def converted(self) -> Tokenizer:
|
def converted(self) -> Tokenizer:
|
||||||
ot = self.original_tokenizer
|
ot = self.original_tokenizer
|
||||||
@@ -1183,6 +1255,7 @@ SLOW_TO_FAST_CONVERTERS = {
|
|||||||
"XLNetTokenizer": XLNetConverter,
|
"XLNetTokenizer": XLNetConverter,
|
||||||
"SplinterTokenizer": SplinterConverter,
|
"SplinterTokenizer": SplinterConverter,
|
||||||
"XGLMTokenizer": XGLMConverter,
|
"XGLMTokenizer": XGLMConverter,
|
||||||
|
"LlamaTokenizer": LlamaConverter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,13 @@ else:
|
|||||||
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
|
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
|
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
|
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, None)),
|
(
|
||||||
|
"llama",
|
||||||
|
(
|
||||||
|
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||||
|
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||||
|
),
|
||||||
|
),
|
||||||
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
|
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
(
|
(
|
||||||
"longt5",
|
"longt5",
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from ...utils import (
|
|||||||
OptionalDependencyNotAvailable,
|
OptionalDependencyNotAvailable,
|
||||||
_LazyModule,
|
_LazyModule,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,6 +34,14 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["tokenization_llama"] = ["LlamaTokenizer"]
|
_import_structure["tokenization_llama"] = ["LlamaTokenizer"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_tokenizers_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
@@ -58,6 +67,14 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .tokenization_llama import LlamaTokenizer
|
from .tokenization_llama import LlamaTokenizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_tokenizers_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
from .tokenization_llama_fast import LlamaTokenizerFast
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
|
|||||||
82
src/transformers/models/llama/tokenization_llama_fast.py
Normal file
82
src/transformers/models/llama/tokenization_llama_fast.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
|
from ...utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
|
require_version("tokenizers>=0.13.3")
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||||
|
"""
|
||||||
|
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||||
|
|
||||||
|
This uses notably ByteFallback and no normalization.
|
||||||
|
|
||||||
|
```
|
||||||
|
from transformers import LlamaTokenizerFast
|
||||||
|
|
||||||
|
tokenizer = LlaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
|
tokenizer.encode("Hello this is a test")
|
||||||
|
>>> [1, 15043, 445, 338, 263, 1243]
|
||||||
|
```
|
||||||
|
|
||||||
|
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
||||||
|
refer to this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (`str`):
|
||||||
|
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
|
||||||
|
contains the vocabulary necessary to instantiate a tokenizer.
|
||||||
|
tokenizer_file (`str`):
|
||||||
|
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
|
||||||
|
contains everything needed to load the tokenizer.
|
||||||
|
|
||||||
|
clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
|
||||||
|
Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
|
||||||
|
spaces.
|
||||||
|
|
||||||
|
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||||
|
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||||
|
|
||||||
|
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
|
||||||
|
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
padding_side = "left"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file=None,
|
||||||
|
tokenizer_file=None,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
unk_token="<unk>",
|
||||||
|
bos_token="<s>",
|
||||||
|
eos_token="</s>",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vocab_file=vocab_file,
|
||||||
|
tokenizer_file=tokenizer_file,
|
||||||
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
unk_token=unk_token,
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
@@ -219,6 +219,13 @@ class LEDTokenizerFast(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tokenizers"])
|
requires_backends(self, ["tokenizers"])
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaTokenizerFast(metaclass=DummyObject):
|
||||||
|
_backends = ["tokenizers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tokenizers"])
|
||||||
|
|
||||||
|
|
||||||
class LongformerTokenizerFast(metaclass=DummyObject):
|
class LongformerTokenizerFast(metaclass=DummyObject):
|
||||||
_backends = ["tokenizers"]
|
_backends = ["tokenizers"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
# coding=utf-8
|
||||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -23,8 +24,10 @@ from transformers import (
|
|||||||
SPIECE_UNDERLINE,
|
SPIECE_UNDERLINE,
|
||||||
AddedToken,
|
AddedToken,
|
||||||
LlamaTokenizer,
|
LlamaTokenizer,
|
||||||
|
LlamaTokenizerFast,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
@@ -287,13 +290,11 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class LlamaIntegrationTest(unittest.TestCase):
|
class LlamaIntegrationTest(unittest.TestCase):
|
||||||
checkpoint_name = "hf-internal-testing/llama-tokenizer"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(cls.checkpoint_name)
|
checkpoint_name = "hf-internal-testing/llama-tokenizer"
|
||||||
cls.rust_tokenizer = cls.tokenizer # TODO @narsil replace with the rust one
|
cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(checkpoint_name)
|
||||||
cls.pad_token_id = 1
|
cls.rust_tokenizer = LlamaTokenizerFast.from_pretrained(checkpoint_name)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -314,6 +315,27 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_conversion(self):
|
||||||
|
# This is excruciatingly slow since it has to recreate the entire merge
|
||||||
|
# list from the original vocabulary in spm
|
||||||
|
self.rust_tokenizer.save_pretrained("./out")
|
||||||
|
with tempfile.TemporaryDirectory() as dirname:
|
||||||
|
self.rust_tokenizer.save_pretrained(dirname)
|
||||||
|
|
||||||
|
with open(os.path.join(dirname, "tokenizer.json"), "r") as f:
|
||||||
|
old_serialized = f.read()
|
||||||
|
|
||||||
|
new_tokenizer = convert_slow_tokenizer(self.tokenizer)
|
||||||
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
|
new_tokenizer.save(f.name)
|
||||||
|
# Re-opening since `f` is in bytes.
|
||||||
|
new_serialized = open(f.name, "r").read()
|
||||||
|
with open("out_tokenizer.json", "w") as g:
|
||||||
|
g.write(new_serialized)
|
||||||
|
|
||||||
|
self.assertEqual(old_serialized, new_serialized)
|
||||||
|
|
||||||
def test_simple_encode_decode(self):
|
def test_simple_encode_decode(self):
|
||||||
pyth_tokenizer = self.tokenizer
|
pyth_tokenizer = self.tokenizer
|
||||||
rust_tokenizer = self.rust_tokenizer
|
rust_tokenizer = self.rust_tokenizer
|
||||||
@@ -362,11 +384,27 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||||
self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||||
|
|
||||||
|
def test_no_differences_showcase(self):
|
||||||
|
pyth_tokenizer = self.tokenizer
|
||||||
|
rust_tokenizer = self.rust_tokenizer
|
||||||
|
self.assertEqual(pyth_tokenizer.encode(""), [1])
|
||||||
|
self.assertEqual(rust_tokenizer.encode(""), [1])
|
||||||
|
|
||||||
|
self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
|
||||||
|
self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
|
||||||
|
|
||||||
|
self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678])
|
||||||
|
self.assertEqual(rust_tokenizer.encode(" "), [1, 1678])
|
||||||
|
|
||||||
|
self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||||
|
self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||||
|
|
||||||
self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
|
self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
|
||||||
self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
|
self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
|
||||||
|
|
||||||
self.assertEqual(pyth_tokenizer.encode(""), [1])
|
def test_no_differences_decode(self):
|
||||||
self.assertEqual(rust_tokenizer.encode(""), [1])
|
pyth_tokenizer = self.tokenizer
|
||||||
|
rust_tokenizer = self.rust_tokenizer
|
||||||
|
|
||||||
self.assertEqual(pyth_tokenizer.decode([869]), ".")
|
self.assertEqual(pyth_tokenizer.decode([869]), ".")
|
||||||
self.assertEqual(rust_tokenizer.decode([869]), ".")
|
self.assertEqual(rust_tokenizer.decode([869]), ".")
|
||||||
@@ -374,6 +412,15 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
|
self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
|
||||||
self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")
|
self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")
|
||||||
|
|
||||||
|
def test_no_differences_special_tokens(self):
|
||||||
|
pyth_tokenizer = self.tokenizer
|
||||||
|
rust_tokenizer = self.rust_tokenizer
|
||||||
|
self.assertEqual(pyth_tokenizer.encode(""), [1])
|
||||||
|
self.assertEqual(rust_tokenizer.encode(""), [1])
|
||||||
|
|
||||||
|
self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
|
||||||
|
self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
|
os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
|
||||||
"RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
|
"RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
|
||||||
@@ -392,8 +439,8 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(encoded1, encoded2)
|
self.assertEqual(encoded1, encoded2)
|
||||||
|
|
||||||
decoded1 = pyth_tokenizer.decode(encoded1)
|
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
|
||||||
decoded2 = rust_tokenizer.decode(encoded2)
|
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(decoded1, decoded2)
|
self.assertEqual(decoded1, decoded2)
|
||||||
|
|
||||||
@@ -406,7 +453,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(encoded1, encoded2)
|
self.assertEqual(encoded1, encoded2)
|
||||||
|
|
||||||
decoded1 = pyth_tokenizer.decode(encoded1)
|
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
|
||||||
decoded2 = rust_tokenizer.decode(encoded2)
|
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(decoded1, decoded2)
|
self.assertEqual(decoded1, decoded2)
|
||||||
|
|||||||
@@ -24,11 +24,10 @@ class ConvertSlowTokenizerTest(unittest.TestCase):
|
|||||||
|
|
||||||
original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)
|
original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)
|
||||||
|
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with self.assertRaises(RuntimeError) as cm:
|
||||||
_ = SpmConverter(original_tokenizer_with_bytefallback)
|
_ = SpmConverter(original_tokenizer_with_bytefallback)
|
||||||
self.assertEqual(len(w), 1)
|
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
||||||
" which is not implemented in the fast tokenizers.",
|
" which is not implemented in the fast tokenizers.",
|
||||||
str(w[0].message),
|
str(cm.exception),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user