move back to simple space splitting
This commit is contained in:
@@ -194,7 +194,7 @@ def main():
|
|||||||
elif args.length < 0:
|
elif args.length < 0:
|
||||||
args.length = MAX_LENGTH # avoid infinite loop
|
args.length = MAX_LENGTH # avoid infinite loop
|
||||||
|
|
||||||
print(args)
|
logger.info(args)
|
||||||
if args.model_type in ["ctrl"]:
|
if args.model_type in ["ctrl"]:
|
||||||
if args.temperature > 0.7 :
|
if args.temperature > 0.7 :
|
||||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||||
|
|||||||
@@ -22,9 +22,6 @@ import os
|
|||||||
import regex as re
|
import regex as re
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import sacremoses as sm
|
|
||||||
|
|
||||||
from .tokenization_xlm import replace_unicode_punct, remove_non_printing_char
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -81,9 +78,6 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
|
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
|
||||||
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
|
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
|
||||||
|
|
||||||
self.punct_normalizer = sm.MosesPunctNormalizer(lang='en')
|
|
||||||
self.moses_tokenizer = sm.MosesTokenizer(lang='en')
|
|
||||||
|
|
||||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||||
@@ -138,22 +132,12 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
def moses_pipeline(self, text):
|
def _tokenize(self, text):
|
||||||
text = replace_unicode_punct(text)
|
|
||||||
text = self.punct_normalizer.normalize(text)
|
|
||||||
text = remove_non_printing_char(text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
def _tokenize(self, text, bypass_tokenizer=False):
|
|
||||||
""" Tokenize a string.
|
""" Tokenize a string.
|
||||||
"""
|
"""
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
|
|
||||||
if bypass_tokenizer:
|
text = text.split(' ')
|
||||||
text = text.split()
|
|
||||||
else:
|
|
||||||
text = self.moses_pipeline(text)
|
|
||||||
text = self.moses_tokenizer.tokenize(text, return_str=False, escape=False)
|
|
||||||
|
|
||||||
for token in text:
|
for token in text:
|
||||||
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||||
|
|||||||
Reference in New Issue
Block a user