[RoFormer] Fix some issues (#12397)
* add RoFormerTokenizerFast into AutoTokenizer
* fix typo in roformer docs
* make onnx export happy
* update RoFormerConfig embedding_size
* use jieba not rjieba
* fix 12244 and make test_alignement passed
* update ARCHIVE_MAP
* make style & quality & fixup
* update
* make style & quality & fixup
* make style quality fixup
* update
* suggestion from LysandreJik
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* make style
* use rjieba
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -56,7 +56,7 @@ RoFormerTokenizer
|
|||||||
create_token_type_ids_from_sequences, save_vocabulary
|
create_token_type_ids_from_sequences, save_vocabulary
|
||||||
|
|
||||||
|
|
||||||
RobertaTokenizerFast
|
RoFormerTokenizerFast
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.RoFormerTokenizerFast
|
.. autoclass:: transformers.RoFormerTokenizerFast
|
||||||
|
|||||||
@@ -315,6 +315,10 @@ def is_datasets_available():
|
|||||||
return _datasets_available
|
return _datasets_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_rjieba_available():
|
||||||
|
return importlib.util.find_spec("rjieba") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_psutil_available():
|
def is_psutil_available():
|
||||||
return importlib.util.find_spec("psutil") is not None
|
return importlib.util.find_spec("psutil") is not None
|
||||||
|
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ if is_tokenizers_available():
|
|||||||
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
|
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
|
||||||
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
|
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
|
||||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||||
|
from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
|
||||||
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
|
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
|
||||||
from ..t5.tokenization_t5_fast import T5TokenizerFast
|
from ..t5.tokenization_t5_fast import T5TokenizerFast
|
||||||
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
||||||
@@ -232,6 +233,7 @@ else:
|
|||||||
ReformerTokenizerFast = None
|
ReformerTokenizerFast = None
|
||||||
RetriBertTokenizerFast = None
|
RetriBertTokenizerFast = None
|
||||||
RobertaTokenizerFast = None
|
RobertaTokenizerFast = None
|
||||||
|
RoFormerTokenizerFast = None
|
||||||
SqueezeBertTokenizerFast = None
|
SqueezeBertTokenizerFast = None
|
||||||
T5TokenizerFast = None
|
T5TokenizerFast = None
|
||||||
XLMRobertaTokenizerFast = None
|
XLMRobertaTokenizerFast = None
|
||||||
@@ -245,7 +247,7 @@ logger = logging.get_logger(__name__)
|
|||||||
TOKENIZER_MAPPING = OrderedDict(
|
TOKENIZER_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
||||||
(RoFormerConfig, (RoFormerTokenizer, None)),
|
(RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
|
||||||
(T5Config, (T5Tokenizer, T5TokenizerFast)),
|
(T5Config, (T5Tokenizer, T5TokenizerFast)),
|
||||||
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
|
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
|
||||||
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
|
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
|
||||||
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json"
|
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
|
||||||
|
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
|
||||||
|
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
|
||||||
|
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
|
||||||
|
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
|
||||||
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,8 +47,9 @@ class RoFormerConfig(PretrainedConfig):
|
|||||||
Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
|
Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
|
||||||
the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
|
the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
|
||||||
:class:`~transformers.TFRoFormerModel`.
|
:class:`~transformers.TFRoFormerModel`.
|
||||||
embedding_size (:obj:`int`, `optional`, defaults to 768):
|
embedding_size (:obj:`int`, `optional`, defaults to None):
|
||||||
Dimensionality of the encoder layers and the pooler layer.
|
Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not
|
||||||
|
provided.
|
||||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||||
Dimension of the encoder layers and the pooler layer.
|
Dimension of the encoder layers and the pooler layer.
|
||||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||||
@@ -96,7 +101,7 @@ class RoFormerConfig(PretrainedConfig):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size=50000,
|
vocab_size=50000,
|
||||||
embedding_size=768,
|
embedding_size=None,
|
||||||
hidden_size=768,
|
hidden_size=768,
|
||||||
num_hidden_layers=12,
|
num_hidden_layers=12,
|
||||||
num_attention_heads=12,
|
num_attention_heads=12,
|
||||||
@@ -117,7 +122,7 @@ class RoFormerConfig(PretrainedConfig):
|
|||||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.embedding_size = embedding_size
|
self.embedding_size = hidden_size if embedding_size is None else embedding_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|||||||
@@ -60,7 +60,11 @@ _TOKENIZER_FOR_DOC = "RoFormerTokenizer"
|
|||||||
|
|
||||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"junnyu/roformer_chinese_small",
|
"junnyu/roformer_chinese_small",
|
||||||
"junnyu/roformer_chinese_base"
|
"junnyu/roformer_chinese_base",
|
||||||
|
"junnyu/roformer_chinese_char_small",
|
||||||
|
"junnyu/roformer_chinese_char_base",
|
||||||
|
"junnyu/roformer_small_discriminator",
|
||||||
|
"junnyu/roformer_small_generator"
|
||||||
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -327,9 +331,9 @@ class RoFormerSelfAttention(nn.Module):
|
|||||||
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
|
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
|
||||||
sin, cos = sinusoidal_pos.chunk(2, dim=-1)
|
sin, cos = sinusoidal_pos.chunk(2, dim=-1)
|
||||||
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
||||||
sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
|
sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
|
||||||
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
||||||
cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
|
cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
|
||||||
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
|
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
|
||||||
rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
|
rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
|
||||||
query_layer
|
query_layer
|
||||||
|
|||||||
@@ -65,7 +65,11 @@ _TOKENIZER_FOR_DOC = "RoFormerTokenizer"
|
|||||||
|
|
||||||
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"junnyu/roformer_chinese_small",
|
"junnyu/roformer_chinese_small",
|
||||||
"junnyu/roformer_chinese_base"
|
"junnyu/roformer_chinese_base",
|
||||||
|
"junnyu/roformer_chinese_char_small",
|
||||||
|
"junnyu/roformer_chinese_char_base",
|
||||||
|
"junnyu/roformer_small_discriminator",
|
||||||
|
"junnyu/roformer_small_generator"
|
||||||
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -31,15 +31,30 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
"vocab_file": {
|
"vocab_file": {
|
||||||
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
|
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
|
||||||
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
|
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
|
"junnyu/roformer_chinese_small": 1536,
|
||||||
|
"junnyu/roformer_chinese_base": 1536,
|
||||||
|
"junnyu/roformer_chinese_char_small": 512,
|
||||||
|
"junnyu/roformer_chinese_char_base": 512,
|
||||||
|
"junnyu/roformer_small_discriminator": 128,
|
||||||
|
"junnyu/roformer_small_generator": 128,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
PRETRAINED_INIT_CONFIGURATION = {
|
PRETRAINED_INIT_CONFIGURATION = {
|
||||||
"junnyu/roformer_chinese_small": {"do_lower_case": True},
|
"junnyu/roformer_chinese_small": {"do_lower_case": True},
|
||||||
"junnyu/roformer_chinese_base": {"do_lower_case": True},
|
"junnyu/roformer_chinese_base": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_chinese_char_small": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_chinese_char_base": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_small_discriminator": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_small_generator": {"do_lower_case": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -166,13 +181,8 @@ class RoFormerTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def __setstate__(self, d):
|
def __setstate__(self, d):
|
||||||
self.__dict__ = d
|
self.__dict__ = d
|
||||||
try:
|
|
||||||
import rjieba
|
import rjieba
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"You need to install rjieba to use RoFormerTokenizer."
|
|
||||||
"See https://pypi.org/project/rjieba/ for installation."
|
|
||||||
)
|
|
||||||
self.jieba = rjieba
|
self.jieba = rjieba
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
|
|||||||
@@ -33,15 +33,30 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
"vocab_file": {
|
"vocab_file": {
|
||||||
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
|
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
|
||||||
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
|
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
|
||||||
|
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
|
"junnyu/roformer_chinese_small": 1536,
|
||||||
|
"junnyu/roformer_chinese_base": 1536,
|
||||||
|
"junnyu/roformer_chinese_char_small": 512,
|
||||||
|
"junnyu/roformer_chinese_char_base": 512,
|
||||||
|
"junnyu/roformer_small_discriminator": 128,
|
||||||
|
"junnyu/roformer_small_generator": 128,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
PRETRAINED_INIT_CONFIGURATION = {
|
PRETRAINED_INIT_CONFIGURATION = {
|
||||||
"junnyu/roformer_chinese_small": {"do_lower_case": True},
|
"junnyu/roformer_chinese_small": {"do_lower_case": True},
|
||||||
"junnyu/roformer_chinese_base": {"do_lower_case": True},
|
"junnyu/roformer_chinese_base": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_chinese_char_small": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_chinese_char_base": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_small_discriminator": {"do_lower_case": True},
|
||||||
|
"junnyu/roformer_small_generator": {"do_lower_case": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -41,26 +41,26 @@ class JiebaPreTokenizer:
|
|||||||
splits = []
|
splits = []
|
||||||
|
|
||||||
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
|
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
|
||||||
# for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
|
for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
|
||||||
# if token in self.vocab:
|
|
||||||
# splits.append(normalized_string.slice((start, end)))
|
|
||||||
# else:
|
|
||||||
# token_list = self.normalizers.normalize_str(token).split()
|
|
||||||
# for token in token_list:
|
|
||||||
# if token:
|
|
||||||
# end = start + len(token)
|
|
||||||
# splits.append(normalized_string.slice((start, end)))
|
|
||||||
# start = end
|
|
||||||
|
|
||||||
# this code test_alignement_methods can't pass but fast (300ms)
|
|
||||||
for token in self.jieba.cut(str(normalized_string), False):
|
|
||||||
if token in self.vocab:
|
if token in self.vocab:
|
||||||
splits.append(NormalizedString(token))
|
splits.append(normalized_string[start:end])
|
||||||
else:
|
else:
|
||||||
token_list = self.normalizers.normalize_str(token).split()
|
token_list = self.normalizers.normalize_str(token).split()
|
||||||
for token in token_list:
|
for token in token_list:
|
||||||
if token:
|
if token:
|
||||||
splits.append(NormalizedString(token))
|
end = start + len(token)
|
||||||
|
splits.append(normalized_string[start:end])
|
||||||
|
start = end
|
||||||
|
|
||||||
|
# this code test_alignement_methods can't pass but fast (300ms)
|
||||||
|
# for token in self.jieba.cut(str(normalized_string), False):
|
||||||
|
# if token in self.vocab:
|
||||||
|
# splits.append(NormalizedString(token))
|
||||||
|
# else:
|
||||||
|
# token_list = self.normalizers.normalize_str(token).split()
|
||||||
|
# for token in token_list:
|
||||||
|
# if token:
|
||||||
|
# splits.append(NormalizedString(token))
|
||||||
|
|
||||||
return splits
|
return splits
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .file_utils import (
|
|||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_pandas_available,
|
is_pandas_available,
|
||||||
|
is_rjieba_available,
|
||||||
is_scatter_available,
|
is_scatter_available,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
is_soundfile_availble,
|
is_soundfile_availble,
|
||||||
@@ -223,6 +224,16 @@ def require_git_lfs(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_rjieba(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
|
||||||
|
"""
|
||||||
|
if not is_rjieba_available():
|
||||||
|
return unittest.skip("test requires rjieba")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def require_onnx(test_case):
|
def require_onnx(test_case):
|
||||||
if not is_onnx_available():
|
if not is_onnx_available():
|
||||||
return unittest.skip("test requires ONNX")(test_case)
|
return unittest.skip("test requires ONNX")(test_case)
|
||||||
|
|||||||
@@ -13,29 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import importlib
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
|
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
|
||||||
from transformers.testing_utils import require_tokenizers
|
from transformers.testing_utils import require_rjieba, require_tokenizers
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
def is_rjieba_available():
|
|
||||||
return importlib.util.find_spec("rjieba") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def require_rjieba(test_case):
|
|
||||||
"""
|
|
||||||
Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
|
|
||||||
"""
|
|
||||||
if not is_rjieba_available():
|
|
||||||
return unittest.skip("test requires rjieba")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
@require_rjieba
|
@require_rjieba
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
@@ -79,6 +64,10 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
|
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
|
||||||
|
|
||||||
# due to custom pre_tokenize , char_to_token may be error
|
# can't train new_tokenizer via Tokenizers lib
|
||||||
def test_alignement_methods(self):
|
def test_training_new_tokenizer(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# can't train new_tokenizer via Tokenizers lib
|
||||||
|
def test_training_new_tokenizer_with_special_tokens_change(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user