Documentation about loading a fast tokenizer within Transformers (#11029)
* Documentation about loading a fast tokenizer within Transformers
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* style
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -12,18 +12,33 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
|
||||
from transformers import (
|
||||
BatchEncoding,
|
||||
BertTokenizer,
|
||||
BertTokenizerFast,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
TensorType,
|
||||
TokenSpan,
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
from transformers.testing_utils import CaptureStderr, require_flax, require_tf, require_tokenizers, require_torch, slow
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
|
||||
|
||||
class TokenizerUtilsTest(unittest.TestCase):
|
||||
def check_tokenizer_from_pretrained(self, tokenizer_class):
|
||||
s3_models = list(tokenizer_class.max_model_input_sizes.keys())
|
||||
@@ -253,3 +268,15 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
batch = tokenizer.pad(features, padding=True, return_tensors="tf")
|
||||
self.assertTrue(isinstance(batch["input_ids"], tf.Tensor))
|
||||
self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
|
||||
|
||||
@require_tokenizers
def test_instantiation_from_tokenizers(self):
    """Build a `PreTrainedTokenizerFast` directly from an in-memory `Tokenizer`.

    A `Tokenizer` wrapping a `WordPiece` model (with "[UNK]" as the unknown
    token) is handed over through the `tokenizer_object` keyword. The test
    makes no assertions: it succeeds as long as construction does not raise.
    """
    wordpiece_model = WordPiece(unk_token="[UNK]")
    backend_tokenizer = Tokenizer(wordpiece_model)
    PreTrainedTokenizerFast(tokenizer_object=backend_tokenizer)
|
||||
|
||||
@require_tokenizers
def test_instantiation_from_tokenizers_json_file(self):
    """Build a `PreTrainedTokenizerFast` from a serialized `tokenizer.json` file.

    A `Tokenizer` wrapping a `WordPiece` model is saved as "tokenizer.json"
    inside a temporary directory, and that path is then passed through the
    `tokenizer_file` keyword. The test makes no assertions: it succeeds as
    long as saving and construction do not raise.
    """
    backend_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Compute the path once so the save and the load use the same file.
        json_path = os.path.join(scratch_dir, "tokenizer.json")
        backend_tokenizer.save(json_path)
        PreTrainedTokenizerFast(tokenizer_file=json_path)
|
||||
|
||||
Reference in New Issue
Block a user