From ebd2cb8d74f62e0dd3c2ebc3411ee55d7f5a7b8c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 21 Jun 2019 21:08:44 +0200 Subject: [PATCH] update from_pretrained to load XLNetModel as well --- examples/generation_xlnet.py | 21 ++++++++ pytorch_pretrained_bert/modeling_xlnet.py | 51 ++++++++++++------- pytorch_pretrained_bert/tokenization_xlnet.py | 15 ++++++ tests/modeling_xlnet_test.py | 48 ++++++++++------- 4 files changed, 99 insertions(+), 36 deletions(-) create mode 100644 examples/generation_xlnet.py diff --git a/examples/generation_xlnet.py b/examples/generation_xlnet.py new file mode 100644 index 0000000000..7d83d1bf20 --- /dev/null +++ b/examples/generation_xlnet.py @@ -0,0 +1,21 @@ +import torch +from torch.nn import functional as F +from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer + +import logging +logging.basicConfig(level=logging.INFO) + +tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') +model = XLNetModel.from_pretrained('xlnet-large-cased') +model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased') + +tokens = tokenizer.encode('I am very ') +for i in range(len(tokens), 20): + mask = torch.tensor([[[0.0] * i + [1.0]]]) + logits, _ = model(torch.tensor([tokens + [0]]), + perm_mask=mask.expand(-1, i+1, -1), + target_mapping=mask, + inp_q=mask.squeeze(1)) + output = torch.multinomial(F.softmax(logits[0, 0, :]), 1) + tokens.append(output.item()) + print(tokenizer.decode(tokens)) diff --git a/pytorch_pretrained_bert/modeling_xlnet.py b/pytorch_pretrained_bert/modeling_xlnet.py index 6b7562e48f..f825043e8c 100644 --- a/pytorch_pretrained_bert/modeling_xlnet.py +++ b/pytorch_pretrained_bert/modeling_xlnet.py @@ -727,16 +727,24 @@ class XLNetPreTrainedModel(nn.Module): archive_file, resolved_archive_file)) logger.info("loading configuration file {} from cache at {}".format( config_file, resolved_config_file)) + # Load config config = XLNetConfig.from_json_file(resolved_config_file) logger.info("Model config {}".format(config)) + + # Update config with kwargs if needed + for key, value in kwargs: + if hasattr(config, key): + setattr(config, key, value) + # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint - return load_tf_weights_in_xlnet(model, resolved_archive_file) + return load_tf_weights_in_xlnet(model, config, resolved_archive_file) + # Load from a PyTorch state_dict missing_keys = [] unexpected_keys = [] @@ -755,8 +763,8 @@ class XLNetPreTrainedModel(nn.Module): if child is not None: load(child, prefix + name + '.') start_prefix = '' - if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()): - start_prefix = 'xlnet.' + if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()): + start_prefix = 'transformer.' load(model, prefix=start_prefix) if len(missing_keys) > 0: logger.info("Weights of {} not initialized from pretrained model: {}".format( @@ -989,10 +997,10 @@ class XLNetModel(XLNetPreTrainedModel): output_h = self.dropout(word_emb_k) if inp_q is not None: if target_mapping is not None: - word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1) + word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1) else: inp_q_ext = inp_q[:, :, None] - word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k + word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k output_g = self.dropout(word_emb_q) else: output_g = None @@ -1062,19 +1070,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): This can be used to compute head importance metrics. Default: False Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see XLNet paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. - It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. + inp_k: int32 Tensor in shape [bsz, len], the input token IDs. + seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. + input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. + 0 for real tokens and 1 for padding. + mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory + from previous batches. The length of the list equals n_layer. + If None, no memory is used. + perm_mask: [optional] float32 Tensor in shape [bsz, len, len]. + If perm_mask[k, i, j] = 0, i attend to j in batch k; + if perm_mask[k, i, j] = 1, i does not attend to j in batch k. + If None, each position attends to all the others. + target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len]. + If target_mapping[k, i, j] = 1, the i-th predict in batch k is + on the j-th token. + Only used during pretraining for partial prediction. + Set to None during finetuning. + inp_q: [optional] float32 Tensor in shape [bsz, len]. + 1 for tokens with losses and 0 for tokens without losses. + Only used during pretraining for two-stream attention. + Set to None during finetuning. Outputs: Tuple of (encoded_layers, pooled_output) diff --git a/pytorch_pretrained_bert/tokenization_xlnet.py b/pytorch_pretrained_bert/tokenization_xlnet.py index c9a3d40631..3cc5053338 100644 --- a/pytorch_pretrained_bert/tokenization_xlnet.py +++ b/pytorch_pretrained_bert/tokenization_xlnet.py @@ -37,6 +37,11 @@ VOCAB_NAME = 'spiece.model' SPECIAL_TOKENS_NAME = 'special_tokens.txt' SPIECE_UNDERLINE = '▁' +SEG_ID_A = 0 +SEG_ID_B = 1 +SEG_ID_CLS = 2 +SEG_ID_SEP = 3 +SEG_ID_PAD = 4 class XLNetTokenizer(object): """ @@ -52,6 +57,16 @@ class XLNetTokenizer(object): if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] special_tokens_file = None + if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): + logger.warning("The pre-trained model you are loading is a cased model but you have not set " + "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " + "you may want to check this behavior.") + kwargs['do_lower_case'] = False + elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): + logger.warning("The pre-trained model you are loading is an uncased model but you have set " + "`do_lower_case` to False. We are setting `do_lower_case=True` for you " + "but you may want to check this behavior.") + kwargs['do_lower_case'] = True else: vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) diff --git a/tests/modeling_xlnet_test.py b/tests/modeling_xlnet_test.py index 65d2c6648d..c99cfe25dd 100644 --- a/tests/modeling_xlnet_test.py +++ b/tests/modeling_xlnet_test.py @@ -78,23 +78,30 @@ class XLNetModelTest(unittest.TestCase): input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - # inp_k: int32 Tensor in shape [len, bsz], the input token IDs. - # seg_id: int32 Tensor in shape [len, bsz], the input segment IDs. - # input_mask: float32 Tensor in shape [len, bsz], the input mask. + input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) + perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float) + perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token + target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float) + target_mapping[:, 0, -1] = 1.0 # predict last token + inp_q = target_mapping[:, 0, :].clone() # predict last token + + # inp_k: int32 Tensor in shape [bsz, len], the input token IDs. + # seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. + # input_mask: float32 Tensor in shape [bsz, len], the input mask. # 0 for real tokens and 1 for padding. - # mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory + # mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory # from previous batches. The length of the list equals n_layer. # If None, no memory is used. - # perm_mask: float32 Tensor in shape [len, len, bsz]. - # If perm_mask[i, j, k] = 0, i attend to j in batch k; - # if perm_mask[i, j, k] = 1, i does not attend to j in batch k. + # perm_mask: float32 Tensor in shape [bsz, len, len]. + # If perm_mask[k, i, j] = 0, i attend to j in batch k; + # if perm_mask[k, i, j] = 1, i does not attend to j in batch k. # If None, each position attends to all the others. - # target_mapping: float32 Tensor in shape [num_predict, len, bsz]. - # If target_mapping[i, j, k] = 1, the i-th predict in batch k is + # target_mapping: float32 Tensor in shape [bsz, num_predict, len]. + # If target_mapping[k, i, j] = 1, the i-th predict in batch k is # on the j-th token. # Only used during pretraining for partial prediction. # Set to None during finetuning. - # inp_q: float32 Tensor in shape [len, bsz]. + # inp_q: float32 Tensor in shape [bsz, len]. # 1 for tokens with losses and 0 for tokens without losses. # Only used during pretraining for two-stream attention. # Set to None during finetuning. @@ -121,30 +128,35 @@ class XLNetModelTest(unittest.TestCase): config.update(run_config) - return (config, input_ids_1, input_ids_2, segment_ids, lm_labels) + return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels) def set_seed(self): random.seed(self.seed) torch.manual_seed(self.seed) - def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels): + def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels): model = XLNetLMHeadModel(config) model.eval() loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels) - lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids) + all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids) loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a) - lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b) + all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b) + + logits, _ = model(input_ids_q, + perm_mask=perm_mask, + target_mapping=target_mapping, + inp_q=inp_q) outputs = { "loss_1": loss_1, "mems_1a": mems_1a, - "lm_logits_1": lm_logits_1, + "all_logits_1": all_logits_1, "mems_1b": mems_1b, "loss_2": loss_2, "mems_2a": mems_2a, - "lm_logits_2": lm_logits_2, + "all_logits_2": all_logits_2, "mems_2b": mems_2b, } return outputs @@ -154,7 +166,7 @@ class XLNetModelTest(unittest.TestCase): list(result["loss_1"].size()), []) self.parent.assertListEqual( - list(result["lm_logits_1"].size()), + list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1a"]), @@ -170,7 +182,7 @@ class XLNetModelTest(unittest.TestCase): list(result["loss_2"].size()), []) self.parent.assertListEqual( - list(result["lm_logits_2"].size()), + list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_2a"]),