[fsmt tokenizer] support lowercase tokenizer (#8389)
* support lowercase tokenizer
* fix arg pos
This commit is contained in:
@@ -133,6 +133,14 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
|
|||||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||||
|
|
||||||
|
# detect whether this is a do_lower_case situation, which can be derived by checking whether we
|
||||||
|
# have at least one uppercase letter in the source vocab
|
||||||
|
do_lower_case = True
|
||||||
|
for k in src_vocab.keys():
|
||||||
|
if not k.islower():
|
||||||
|
do_lower_case = False
|
||||||
|
break
|
||||||
|
|
||||||
tgt_dict = Dictionary.load(tgt_dict_file)
|
tgt_dict = Dictionary.load(tgt_dict_file)
|
||||||
tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
|
tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
|
||||||
tgt_vocab_size = len(tgt_vocab)
|
tgt_vocab_size = len(tgt_vocab)
|
||||||
@@ -207,6 +215,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
|
|||||||
tokenizer_conf = {
|
tokenizer_conf = {
|
||||||
"langs": [src_lang, tgt_lang],
|
"langs": [src_lang, tgt_lang],
|
||||||
"model_max_length": 1024,
|
"model_max_length": 1024,
|
||||||
|
"do_lower_case": do_lower_case,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"Generating {fsmt_tokenizer_config_file}")
|
print(f"Generating {fsmt_tokenizer_config_file}")
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
|||||||
File containing the vocabulary for the target language.
|
File containing the vocabulary for the target language.
|
||||||
merges_file (:obj:`str`):
|
merges_file (:obj:`str`):
|
||||||
File containing the merges.
|
File containing the merges.
|
||||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to lowercase the input when tokenizing.
|
Whether or not to lowercase the input when tokenizing.
|
||||||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
||||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
@@ -186,6 +186,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
|||||||
src_vocab_file=None,
|
src_vocab_file=None,
|
||||||
tgt_vocab_file=None,
|
tgt_vocab_file=None,
|
||||||
merges_file=None,
|
merges_file=None,
|
||||||
|
do_lower_case=False,
|
||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
bos_token="<s>",
|
bos_token="<s>",
|
||||||
sep_token="</s>",
|
sep_token="</s>",
|
||||||
@@ -197,6 +198,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
|||||||
src_vocab_file=src_vocab_file,
|
src_vocab_file=src_vocab_file,
|
||||||
tgt_vocab_file=tgt_vocab_file,
|
tgt_vocab_file=tgt_vocab_file,
|
||||||
merges_file=merges_file,
|
merges_file=merges_file,
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
unk_token=unk_token,
|
unk_token=unk_token,
|
||||||
bos_token=bos_token,
|
bos_token=bos_token,
|
||||||
sep_token=sep_token,
|
sep_token=sep_token,
|
||||||
@@ -207,6 +209,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
|||||||
self.src_vocab_file = src_vocab_file
|
self.src_vocab_file = src_vocab_file
|
||||||
self.tgt_vocab_file = tgt_vocab_file
|
self.tgt_vocab_file = tgt_vocab_file
|
||||||
self.merges_file = merges_file
|
self.merges_file = merges_file
|
||||||
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
# cache of sm.MosesPunctNormalizer instance
|
# cache of sm.MosesPunctNormalizer instance
|
||||||
self.cache_moses_punct_normalizer = dict()
|
self.cache_moses_punct_normalizer = dict()
|
||||||
@@ -351,6 +354,9 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
|||||||
# raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
|
# raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
|
||||||
lang = self.src_lang
|
lang = self.src_lang
|
||||||
|
|
||||||
|
if self.do_lower_case:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
if bypass_tokenizer:
|
if bypass_tokenizer:
|
||||||
text = text.split()
|
text = text.split()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -151,6 +151,13 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
|
decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
|
||||||
self.assertEqual(decoded_text, src_text)
|
self.assertEqual(decoded_text, src_text)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_tokenizer_lower(self):
|
||||||
|
tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
|
||||||
|
tokens = tokenizer.tokenize("USA is United States of America")
|
||||||
|
expected = ["us", "a</w>", "is</w>", "un", "i", "ted</w>", "st", "ates</w>", "of</w>", "am", "er", "ica</w>"]
|
||||||
|
self.assertListEqual(tokens, expected)
|
||||||
|
|
||||||
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
|
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
|
||||||
def test_torch_encode_plus_sent_to_model(self):
|
def test_torch_encode_plus_sent_to_model(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user