cleaning up tokenizer tests structure (at last) - last remaining ppb refs
This commit is contained in:
11
README.md
11
README.md
@@ -345,8 +345,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')
|
|||||||
|
|
||||||
### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules
|
### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules
|
||||||
|
|
||||||
The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
|
The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer which has a few differences:
|
||||||
The new optimizer `AdamW` matches PyTorch `Adam` optimizer API.
|
|
||||||
|
- it only implements weight decay correction,
|
||||||
|
- schedules are now external (see below),
|
||||||
|
- gradient clipping is now also external (see below).
|
||||||
|
|
||||||
|
The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.
|
||||||
|
|
||||||
The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
|
The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
|
||||||
|
|
||||||
@@ -355,6 +360,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
|
|||||||
```python
|
```python
|
||||||
# Parameters:
|
# Parameters:
|
||||||
lr = 1e-3
|
lr = 1e-3
|
||||||
|
max_grad_norm = 1.0
|
||||||
num_total_steps = 1000
|
num_total_steps = 1000
|
||||||
num_warmup_steps = 100
|
num_warmup_steps = 100
|
||||||
warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1
|
warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1
|
||||||
@@ -374,6 +380,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot
|
|||||||
for batch in train_data:
|
for batch in train_data:
|
||||||
loss = model(batch)
|
loss = model(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -68,8 +68,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')
|
|||||||
|
|
||||||
### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules
|
### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules
|
||||||
|
|
||||||
The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
|
The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer which has a few differences:
|
||||||
The new optimizer `AdamW` matches PyTorch `Adam` optimizer API.
|
|
||||||
|
- it only implements weight decay correction,
|
||||||
|
- schedules are now external (see below),
|
||||||
|
- gradient clipping is now also external (see below).
|
||||||
|
|
||||||
|
The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.
|
||||||
|
|
||||||
The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
|
The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
|
||||||
|
|
||||||
@@ -78,6 +83,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
|
|||||||
```python
|
```python
|
||||||
# Parameters:
|
# Parameters:
|
||||||
lr = 1e-3
|
lr = 1e-3
|
||||||
|
max_grad_norm = 1.0
|
||||||
num_total_steps = 1000
|
num_total_steps = 1000
|
||||||
num_warmup_steps = 100
|
num_warmup_steps = 100
|
||||||
warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1
|
warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1
|
||||||
@@ -97,6 +103,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot
|
|||||||
for batch in train_data:
|
for batch in train_data:
|
||||||
loss = model(batch)
|
loss = model(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ Here is the recommended way of saving the model, configuration and vocabulary to
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
|
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||||
|
|
||||||
output_dir = "./models/"
|
output_dir = "./models/"
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig
|
from pytorch_transformers import BertModel, BertTokenizer, BertConfig
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
enc = BertTokenizer.from_pretrained("bert-base-uncased")
|
enc = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
@@ -105,6 +105,9 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``
|
|||||||
# The model needs to be in evaluation mode
|
# The model needs to be in evaluation mode
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
|
||||||
|
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
|
||||||
|
|
||||||
# Creating the trace
|
# Creating the trace
|
||||||
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
|
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
|
||||||
torch.jit.save(traced_model, "traced_bert.pt")
|
torch.jit.save(traced_model, "traced_bert.pt")
|
||||||
|
|||||||
@@ -39,4 +39,4 @@ from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
|||||||
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||||
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||||
|
|
||||||
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import argparse
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from pytorch_pretrained_bert.modeling import BertModel
|
from pytorch_transformers.modeling import BertModel
|
||||||
|
|
||||||
|
|
||||||
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
||||||
|
|||||||
@@ -38,10 +38,13 @@ except ImportError:
|
|||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
||||||
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
|
os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
|
||||||
except (AttributeError, ImportError):
|
except (AttributeError, ImportError):
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
|
||||||
default_cache_path)
|
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||||
|
default_cache_path))
|
||||||
|
|
||||||
|
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -70,7 +73,7 @@ def filename_to_url(filename, cache_dir=None):
|
|||||||
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
||||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||||
cache_dir = str(cache_dir)
|
cache_dir = str(cache_dir)
|
||||||
|
|
||||||
@@ -98,7 +101,7 @@ def cached_path(url_or_filename, cache_dir=None):
|
|||||||
make sure the file exists and then return the path.
|
make sure the file exists and then return the path.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
||||||
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
|
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
|
||||||
url_or_filename = str(url_or_filename)
|
url_or_filename = str(url_or_filename)
|
||||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||||
@@ -187,7 +190,7 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
If it's not there, download it. Then return the path to the cached file.
|
If it's not there, download it. Then return the path to the cached file.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
||||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||||
cache_dir = str(cache_dir)
|
cache_dir = str(cache_dir)
|
||||||
if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
|
if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
|
||||||
|
|||||||
@@ -24,26 +24,33 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
|
|||||||
_is_control, _is_punctuation,
|
_is_control, _is_punctuation,
|
||||||
_is_whitespace, VOCAB_FILES_NAMES)
|
_is_whitespace, VOCAB_FILES_NAMES)
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
class TokenizationTest(unittest.TestCase):
|
class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
|
tokenizer_class = BertTokenizer
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(BertTokenizationTest, self).setUp()
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
|
||||||
vocab_tokens = [
|
vocab_tokens = [
|
||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing", ",", "low", "lowest",
|
"##ing", ",", "low", "lowest",
|
||||||
]
|
]
|
||||||
with TemporaryDirectory() as tmpdirname:
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
return BertTokenizer.from_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
input_text = u"UNwant\u00E9d,running"
|
input_text = u"UNwant\u00E9d,running"
|
||||||
output_text = u"unwanted, running"
|
output_text = u"unwanted, running"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
|
def test_full_tokenizer(self):
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
tokenizer = BertTokenizer(vocab_file)
|
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
|
|||||||
@@ -20,33 +20,40 @@ import json
|
|||||||
|
|
||||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
class GPT2TokenizationTest(unittest.TestCase):
|
class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
tokenizer_class = GPT2Tokenizer
|
||||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
|
||||||
|
def setUp(self):
|
||||||
|
super(GPT2TokenizationTest, self).setUp()
|
||||||
|
|
||||||
|
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
"lo", "low", "er",
|
"lo", "low", "er",
|
||||||
"low", "lowest", "newer", "wider", "<unk>"]
|
"low", "lowest", "newer", "wider", "<unk>"]
|
||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
||||||
special_tokens_map = {"unk_token": "<unk>"}
|
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||||
|
|
||||||
with TemporaryDirectory() as tmpdirname:
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
with open(self.vocab_file, "w") as fp:
|
||||||
with open(vocab_file, "w") as fp:
|
|
||||||
fp.write(json.dumps(vocab_tokens))
|
fp.write(json.dumps(vocab_tokens))
|
||||||
with open(merges_file, "w") as fp:
|
with open(self.merges_file, "w") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
input_text = u"lower newer"
|
input_text = u"lower newer"
|
||||||
output_text = u"lower<unk>newer"
|
output_text = u"lower<unk>newer"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
def test_full_tokenizer(self):
|
||||||
|
tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
|
||||||
text = "lower"
|
text = "lower"
|
||||||
bpe_tokens = ["low", "er"]
|
bpe_tokens = ["low", "er"]
|
||||||
tokens = tokenizer.tokenize(text)
|
tokens = tokenizer.tokenize(text)
|
||||||
|
|||||||
@@ -20,13 +20,17 @@ import json
|
|||||||
|
|
||||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTTokenizationTest(unittest.TestCase):
|
class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
tokenizer_class = OpenAIGPTTokenizer
|
||||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
|
||||||
|
def setUp(self):
|
||||||
|
super(OpenAIGPTTokenizationTest, self).setUp()
|
||||||
|
|
||||||
|
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
"w</w>", "r</w>", "t</w>",
|
"w</w>", "r</w>", "t</w>",
|
||||||
"lo", "low", "er</w>",
|
"lo", "low", "er</w>",
|
||||||
@@ -34,20 +38,24 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
|||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
||||||
|
|
||||||
with TemporaryDirectory() as tmpdirname:
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
with open(self.vocab_file, "w") as fp:
|
||||||
with open(vocab_file, "w") as fp:
|
|
||||||
fp.write(json.dumps(vocab_tokens))
|
fp.write(json.dumps(vocab_tokens))
|
||||||
with open(merges_file, "w") as fp:
|
with open(self.merges_file, "w") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
input_text = u"lower newer"
|
input_text = u"lower newer"
|
||||||
output_text = u"lower newer"
|
output_text = u"lower newer"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
|
|
||||||
|
|
||||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
def test_full_tokenizer(self):
|
||||||
|
tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
|
||||||
|
|
||||||
text = "lower"
|
text = "lower"
|
||||||
bpe_tokens = ["low", "er</w>"]
|
bpe_tokens = ["low", "er</w>"]
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import sys
|
|||||||
from io import open
|
from io import open
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
|
import unittest
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
@@ -36,8 +37,26 @@ else:
|
|||||||
unicode = str
|
unicode = str
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
class CommonTestCases:
|
||||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
|
||||||
|
class CommonTokenizerTester(unittest.TestCase):
|
||||||
|
|
||||||
|
tokenizer_class = None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_save_and_load_tokenizer(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||||
|
|
||||||
@@ -46,11 +65,11 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, *
|
|||||||
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||||
tester.assertListEqual(before_tokens, after_tokens)
|
self.assertListEqual(before_tokens, after_tokens)
|
||||||
|
|
||||||
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
def test_pickle_tokenizer(self):
|
||||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
tokenizer = self.get_tokenizer()
|
||||||
tester.assertIsNotNone(tokenizer)
|
self.assertIsNotNone(tokenizer)
|
||||||
|
|
||||||
text = u"Munich and Berlin are nice cities"
|
text = u"Munich and Berlin are nice cities"
|
||||||
subwords = tokenizer.tokenize(text)
|
subwords = tokenizer.tokenize(text)
|
||||||
@@ -64,32 +83,32 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
|
|||||||
|
|
||||||
subwords_loaded = tokenizer_new.tokenize(text)
|
subwords_loaded = tokenizer_new.tokenize(text)
|
||||||
|
|
||||||
tester.assertListEqual(subwords, subwords_loaded)
|
self.assertListEqual(subwords, subwords_loaded)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
def test_add_tokens_tokenizer(self):
|
||||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
all_size = len(tokenizer)
|
all_size = len(tokenizer)
|
||||||
|
|
||||||
tester.assertNotEqual(vocab_size, 0)
|
self.assertNotEqual(vocab_size, 0)
|
||||||
tester.assertEqual(vocab_size, all_size)
|
self.assertEqual(vocab_size, all_size)
|
||||||
|
|
||||||
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
|
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
|
||||||
added_toks = tokenizer.add_tokens(new_toks)
|
added_toks = tokenizer.add_tokens(new_toks)
|
||||||
vocab_size_2 = tokenizer.vocab_size
|
vocab_size_2 = tokenizer.vocab_size
|
||||||
all_size_2 = len(tokenizer)
|
all_size_2 = len(tokenizer)
|
||||||
|
|
||||||
tester.assertNotEqual(vocab_size_2, 0)
|
self.assertNotEqual(vocab_size_2, 0)
|
||||||
tester.assertEqual(vocab_size, vocab_size_2)
|
self.assertEqual(vocab_size, vocab_size_2)
|
||||||
tester.assertEqual(added_toks, len(new_toks))
|
self.assertEqual(added_toks, len(new_toks))
|
||||||
tester.assertEqual(all_size_2, all_size + len(new_toks))
|
self.assertEqual(all_size_2, all_size + len(new_toks))
|
||||||
|
|
||||||
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
|
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
|
||||||
tester.assertGreaterEqual(len(tokens), 4)
|
self.assertGreaterEqual(len(tokens), 4)
|
||||||
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||||
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
|
|
||||||
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
||||||
'pad_token': "<<<<<|||>|>>>>|>"}
|
'pad_token': "<<<<<|||>|>>>>|>"}
|
||||||
@@ -97,52 +116,45 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
|
|||||||
vocab_size_3 = tokenizer.vocab_size
|
vocab_size_3 = tokenizer.vocab_size
|
||||||
all_size_3 = len(tokenizer)
|
all_size_3 = len(tokenizer)
|
||||||
|
|
||||||
tester.assertNotEqual(vocab_size_3, 0)
|
self.assertNotEqual(vocab_size_3, 0)
|
||||||
tester.assertEqual(vocab_size, vocab_size_3)
|
self.assertEqual(vocab_size, vocab_size_3)
|
||||||
tester.assertEqual(added_toks_2, len(new_toks_2))
|
self.assertEqual(added_toks_2, len(new_toks_2))
|
||||||
tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||||
|
|
||||||
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
||||||
|
|
||||||
tester.assertGreaterEqual(len(tokens), 6)
|
self.assertGreaterEqual(len(tokens), 6)
|
||||||
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||||
tester.assertGreater(tokens[0], tokens[1])
|
self.assertGreater(tokens[0], tokens[1])
|
||||||
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
tester.assertGreater(tokens[-2], tokens[-3])
|
self.assertGreater(tokens[-2], tokens[-3])
|
||||||
tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
|
self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
|
||||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
def test_required_methods_tokenizer(self):
|
||||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
tokenizer = self.get_tokenizer()
|
||||||
|
input_text, output_text = self.get_input_output_texts()
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(input_text)
|
tokens = tokenizer.tokenize(input_text)
|
||||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
ids_2 = tokenizer.encode(input_text)
|
ids_2 = tokenizer.encode(input_text)
|
||||||
tester.assertListEqual(ids, ids_2)
|
self.assertListEqual(ids, ids_2)
|
||||||
|
|
||||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||||
text_2 = tokenizer.decode(ids)
|
text_2 = tokenizer.decode(ids)
|
||||||
|
|
||||||
tester.assertEqual(text_2, output_text)
|
self.assertEqual(text_2, output_text)
|
||||||
|
|
||||||
tester.assertNotEqual(len(tokens_2), 0)
|
self.assertNotEqual(len(tokens_2), 0)
|
||||||
tester.assertIsInstance(text_2, (str, unicode))
|
self.assertIsInstance(text_2, (str, unicode))
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
def test_pretrained_model_lists(self):
|
||||||
weights_list = list(tokenizer_class.max_model_input_sizes.keys())
|
weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
|
||||||
weights_lists_2 = []
|
weights_lists_2 = []
|
||||||
for file_id, map_list in tokenizer_class.pretrained_vocab_files_map.items():
|
for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
|
||||||
weights_lists_2.append(list(map_list.keys()))
|
weights_lists_2.append(list(map_list.keys()))
|
||||||
|
|
||||||
for weights_list_2 in weights_lists_2:
|
for weights_list_2 in weights_lists_2:
|
||||||
tester.assertListEqual(weights_list, weights_list_2)
|
self.assertListEqual(weights_list, weights_list_2)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
|
||||||
create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
|
|
||||||
create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
|
|
||||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
|
||||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
|
||||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
|
||||||
|
|||||||
@@ -20,26 +20,33 @@ from io import open
|
|||||||
|
|
||||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
from.tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
class TransfoXLTokenizationTest(unittest.TestCase):
|
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
|
tokenizer_class = TransfoXLTokenizer
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(TransfoXLTokenizationTest, self).setUp()
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
|
||||||
vocab_tokens = [
|
vocab_tokens = [
|
||||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
|
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
|
||||||
"running", ",", "low", "l",
|
"running", ",", "low", "l",
|
||||||
]
|
]
|
||||||
with TemporaryDirectory() as tmpdirname:
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
input_text = u"<unk> UNwanted , running"
|
input_text = u"<unk> UNwanted , running"
|
||||||
output_text = u"<unk> unwanted, running"
|
output_text = u"<unk> unwanted, running"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
def test_full_tokenizer(self):
|
||||||
|
tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)
|
||||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
||||||
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||||
|
|||||||
@@ -20,12 +20,16 @@ import json
|
|||||||
|
|
||||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
class XLMTokenizationTest(unittest.TestCase):
|
class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
tokenizer_class = XLMTokenizer
|
||||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
|
||||||
|
def setUp(self):
|
||||||
|
super(XLMTokenizationTest, self).setUp()
|
||||||
|
|
||||||
|
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
"w</w>", "r</w>", "t</w>",
|
"w</w>", "r</w>", "t</w>",
|
||||||
"lo", "low", "er</w>",
|
"lo", "low", "er</w>",
|
||||||
@@ -33,20 +37,24 @@ class XLMTokenizationTest(unittest.TestCase):
|
|||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
||||||
|
|
||||||
with TemporaryDirectory() as tmpdirname:
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
with open(self.vocab_file, "w") as fp:
|
||||||
with open(vocab_file, "w") as fp:
|
|
||||||
fp.write(json.dumps(vocab_tokens))
|
fp.write(json.dumps(vocab_tokens))
|
||||||
with open(merges_file, "w") as fp:
|
with open(self.merges_file, "w") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
return XLMTokenizer.from_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
input_text = u"lower newer"
|
input_text = u"lower newer"
|
||||||
output_text = u"lower newer"
|
output_text = u"lower newer"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
|
def test_full_tokenizer(self):
|
||||||
|
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
|
||||||
|
|
||||||
text = "lower"
|
text = "lower"
|
||||||
bpe_tokens = ["low", "er</w>"]
|
bpe_tokens = ["low", "er</w>"]
|
||||||
|
|||||||
@@ -19,24 +19,34 @@ import unittest
|
|||||||
|
|
||||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
|
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||||
'fixtures/test_sentencepiece.model')
|
'fixtures/test_sentencepiece.model')
|
||||||
|
|
||||||
class XLNetTokenizationTest(unittest.TestCase):
|
class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
|
tokenizer_class = XLNetTokenizer
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(XLNetTokenizationTest, self).setUp()
|
||||||
|
|
||||||
|
# We have a SentencePiece fixture for testing
|
||||||
|
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||||
|
tokenizer.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
return XLNetTokenizer.from_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
|
input_text = u"This is a test"
|
||||||
|
output_text = u"This is a test"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||||
|
|
||||||
with TemporaryDirectory() as tmpdirname:
|
|
||||||
tokenizer.save_pretrained(tmpdirname)
|
|
||||||
|
|
||||||
input_text = u"This is a test"
|
|
||||||
output_text = u"This is a test"
|
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
|
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u'This is a test')
|
tokens = tokenizer.tokenize(u'This is a test')
|
||||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def whitespace_tokenize(text):
|
|||||||
class BertTokenizer(PreTrainedTokenizer):
|
class BertTokenizer(PreTrainedTokenizer):
|
||||||
r"""
|
r"""
|
||||||
Constructs a BertTokenizer.
|
Constructs a BertTokenizer.
|
||||||
:class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
|
:class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_file: Path to a one-wordpiece-per-line vocabulary file
|
vocab_file: Path to a one-wordpiece-per-line vocabulary file
|
||||||
|
|||||||
@@ -125,42 +125,34 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
@bos_token.setter
|
@bos_token.setter
|
||||||
def bos_token(self, value):
|
def bos_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._bos_token = value
|
self._bos_token = value
|
||||||
|
|
||||||
@eos_token.setter
|
@eos_token.setter
|
||||||
def eos_token(self, value):
|
def eos_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._eos_token = value
|
self._eos_token = value
|
||||||
|
|
||||||
@unk_token.setter
|
@unk_token.setter
|
||||||
def unk_token(self, value):
|
def unk_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._unk_token = value
|
self._unk_token = value
|
||||||
|
|
||||||
@sep_token.setter
|
@sep_token.setter
|
||||||
def sep_token(self, value):
|
def sep_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._sep_token = value
|
self._sep_token = value
|
||||||
|
|
||||||
@pad_token.setter
|
@pad_token.setter
|
||||||
def pad_token(self, value):
|
def pad_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._pad_token = value
|
self._pad_token = value
|
||||||
|
|
||||||
@cls_token.setter
|
@cls_token.setter
|
||||||
def cls_token(self, value):
|
def cls_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._cls_token = value
|
self._cls_token = value
|
||||||
|
|
||||||
@mask_token.setter
|
@mask_token.setter
|
||||||
def mask_token(self, value):
|
def mask_token(self, value):
|
||||||
self.add_tokens([value])
|
|
||||||
self._mask_token = value
|
self._mask_token = value
|
||||||
|
|
||||||
@additional_special_tokens.setter
|
@additional_special_tokens.setter
|
||||||
def additional_special_tokens(self, value):
|
def additional_special_tokens(self, value):
|
||||||
self.add_tokens(value)
|
|
||||||
self._additional_special_tokens = value
|
self._additional_special_tokens = value
|
||||||
|
|
||||||
def __init__(self, max_len=None, **kwargs):
|
def __init__(self, max_len=None, **kwargs):
|
||||||
@@ -179,6 +171,10 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||||
|
if key == 'additional_special_tokens':
|
||||||
|
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
|
||||||
|
else:
|
||||||
|
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
@@ -415,15 +411,39 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
|
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens added to the vocabulary.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
# Let's see how to add a new classification token to GPT-2
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||||
|
model = GPT2Model.from_pretrained('gpt2')
|
||||||
|
|
||||||
|
special_tokens_dict = {'cls_token': '<CLS>'}
|
||||||
|
|
||||||
|
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
print('We have added', num_added_toks, 'tokens')
|
||||||
|
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
||||||
|
|
||||||
|
assert tokenizer.cls_token == '<CLS>'
|
||||||
"""
|
"""
|
||||||
if not special_tokens_dict:
|
if not special_tokens_dict:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
added_tokens = 0
|
||||||
for key, value in special_tokens_dict.items():
|
for key, value in special_tokens_dict.items():
|
||||||
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
|
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
|
||||||
|
if key == 'additional_special_tokens':
|
||||||
|
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
|
||||||
|
added_tokens += self.add_tokens(value)
|
||||||
|
else:
|
||||||
|
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
|
||||||
|
added_tokens += self.add_tokens([value])
|
||||||
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return added_tokens
|
||||||
|
|
||||||
def tokenize(self, text, **kwargs):
|
def tokenize(self, text, **kwargs):
|
||||||
""" Converts a string into a sequence of tokens (strings), using the tokenizer.
|
""" Converts a string into a sequence of tokens (strings), using the tokenizer.
|
||||||
|
|||||||
Reference in New Issue
Block a user