clean up in tokenization
This commit is contained in:
@@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
|
||||||
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||||
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
|
||||||
|
|
||||||
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
||||||
""" Load tf checkpoints in a pytorch model
|
""" Load tf checkpoints in a pytorch model
|
||||||
|
|||||||
@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||||
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
||||||
}
|
}
|
||||||
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||||
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
||||||
}
|
}
|
||||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||||
'gpt2': 1024,
|
'gpt2': 1024,
|
||||||
@@ -263,9 +265,14 @@ class GPT2Tokenizer(object):
|
|||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||||
|
|
||||||
def decode(self, tokens, skip_special_tokens=False):
|
def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||||
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
|
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
|
if clean_up_tokenization_spaces:
|
||||||
|
text = text.replace('<unk>', '')
|
||||||
|
text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
|
||||||
|
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
|
||||||
|
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, vocab_path):
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object):
|
|||||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||||
if clean_up_tokenization_spaces:
|
if clean_up_tokenization_spaces:
|
||||||
out_string = out_string.replace('<unk>', '')
|
out_string = out_string.replace('<unk>', '')
|
||||||
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
|
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
|
||||||
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
|
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
|
||||||
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
|
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
|
||||||
return out_string
|
return out_string
|
||||||
|
|||||||
Reference in New Issue
Block a user