From 181075635d6f8d0596bf2e205fb611389c760ea4 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 21 Jun 2019 23:23:37 +0200 Subject: [PATCH] updating model loading and adding special tokens ids --- examples/generation_xlnet.py | 7 +++--- pytorch_pretrained_bert/modeling_xlnet.py | 9 +++++-- pytorch_pretrained_bert/tokenization_xlnet.py | 24 ++++++++++++++++++- xlnet | 1 + 4 files changed, 34 insertions(+), 7 deletions(-) create mode 160000 xlnet diff --git a/examples/generation_xlnet.py b/examples/generation_xlnet.py index 7d83d1bf20..e54f6a365f 100644 --- a/examples/generation_xlnet.py +++ b/examples/generation_xlnet.py @@ -6,14 +6,13 @@ import logging logging.basicConfig(level=logging.INFO) tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') -model = XLNetModel.from_pretrained('xlnet-large-cased') -model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased') +model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', attn_type='uni') -tokens = tokenizer.encode('I am very ') +tokens = tokenizer.encode('I am very happy') for i in range(len(tokens), 20): mask = torch.tensor([[[0.0] * i + [1.0]]]) logits, _ = model(torch.tensor([tokens + [0]]), - perm_mask=mask.expand(-1, i+1, -1), + # perm_mask=mask.expand(-1, i+1, -1), target_mapping=mask, inp_q=mask.squeeze(1)) output = torch.multinomial(F.softmax(logits[0, 0, :]), 1) diff --git a/pytorch_pretrained_bert/modeling_xlnet.py b/pytorch_pretrained_bert/modeling_xlnet.py index f825043e8c..a5af36ce29 100644 --- a/pytorch_pretrained_bert/modeling_xlnet.py +++ b/pytorch_pretrained_bert/modeling_xlnet.py @@ -730,12 +730,17 @@ class XLNetPreTrainedModel(nn.Module): # Load config config = XLNetConfig.from_json_file(resolved_config_file) - logger.info("Model config {}".format(config)) # Update config with kwargs if needed - for key, value in kwargs: + to_remove = [] + for key, value in kwargs.items(): if hasattr(config, key): setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + 
kwargs.pop(key, None) + + logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) diff --git a/pytorch_pretrained_bert/tokenization_xlnet.py b/pytorch_pretrained_bert/tokenization_xlnet.py index 3cc5053338..2fad20fb02 100644 --- a/pytorch_pretrained_bert/tokenization_xlnet.py +++ b/pytorch_pretrained_bert/tokenization_xlnet.py @@ -36,7 +36,29 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { VOCAB_NAME = 'spiece.model' SPECIAL_TOKENS_NAME = 'special_tokens.txt' -SPIECE_UNDERLINE = '▁' +SPIECE_UNDERLINE = u'▁' + +# Tokens +special_symbols = { + "<unk>" : 0, + "<s>" : 1, + "</s>" : 2, + "<cls>" : 3, + "<sep>" : 4, + "<pad>" : 5, + "<mask>" : 6, + "<eod>" : 7, + "<eop>" : 8, +} + +VOCAB_SIZE = 32000 +UNK_ID = special_symbols["<unk>"] +CLS_ID = special_symbols["<cls>"] +SEP_ID = special_symbols["<sep>"] +MASK_ID = special_symbols["<mask>"] +EOD_ID = special_symbols["<eod>"] + +# Segments (not really needed) SEG_ID_A = 0 SEG_ID_B = 1 SEG_ID_CLS = 2 diff --git a/xlnet b/xlnet new file mode 160000 index 0000000000..cbdedecbc7 --- /dev/null +++ b/xlnet @@ -0,0 +1 @@ +Subproject commit cbdedecbc7951fc000a1547f9feb086c34f0698b