From 32aabe8c33b65dd9877a7df474c9aa84ea1fe354 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 10 Sep 2019 12:17:18 +0200 Subject: [PATCH] WIP XLNet --- pytorch_transformers/__init__.py | 4 +- pytorch_transformers/configuration_utils.py | 1 + pytorch_transformers/modeling_tf_gpt2.py | 74 +- pytorch_transformers/modeling_tf_utils.py | 63 + pytorch_transformers/modeling_tf_xlnet.py | 1121 +++++++++++++++++ .../tests/modeling_tf_common_test.py | 4 +- .../tests/modeling_tf_xlnet_test.py | 341 +++++ 7 files changed, 1540 insertions(+), 68 deletions(-) create mode 100644 pytorch_transformers/modeling_tf_xlnet.py create mode 100644 pytorch_transformers/tests/modeling_tf_xlnet_test.py diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 9546492b3c..dbe979c564 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -95,7 +95,7 @@ except (ImportError, AssertionError): if _tf_available: logger.info("TensorFlow version {} available.".format(tf.__version__)) - from .modeling_tf_utils import TFPreTrainedModel + from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelWithLMHead) @@ -107,7 +107,7 @@ if _tf_available: load_bert_pt_weights_in_tf2, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP) - from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings, + from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel, load_gpt2_pt_weights_in_tf2, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) diff --git a/pytorch_transformers/configuration_utils.py b/pytorch_transformers/configuration_utils.py index 7efc735d41..42346d6b6c 100644 --- a/pytorch_transformers/configuration_utils.py +++ b/pytorch_transformers/configuration_utils.py @@ -54,6 +54,7 @@ class PretrainedConfig(object): self.output_attentions = kwargs.pop('output_attentions', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.torchscript = kwargs.pop('torchscript', False) + self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.pruned_heads = kwargs.pop('pruned_heads', {}) def save_pretrained(self, save_directory): diff --git a/pytorch_transformers/modeling_tf_gpt2.py b/pytorch_transformers/modeling_tf_gpt2.py index a896ee5a5f..900acb94a4 100644 --- a/pytorch_transformers/modeling_tf_gpt2.py +++ b/pytorch_transformers/modeling_tf_gpt2.py @@ -28,7 +28,8 @@ from io import open import numpy as np import tensorflow as tf -from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list +from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings, + TFSequenceSummary, shape_list) from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings @@ -65,6 +66,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path): symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights weight_value_tuples = [] + all_pytorch_weights = set(list(state_dict.keys())) for symbolic_weight in symbolic_weights: name = symbolic_weight.name name = name.replace(':0', '') @@ -100,13 +102,13 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path): weight_value_tuples.append((symbolic_weight, array)) - state_dict.pop(name) + all_pytorch_weights.discard(name) K.batch_set_value(weight_value_tuples) tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run - assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys())) + logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights)) return tf_model @@ -267,65 +269,6 @@ class TFBlock(tf.keras.layers.Layer): outputs = [x] + output_attn[1:] return outputs # x, present, (attentions) -class TFGPT2Embeddings(tf.keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings. - """ - def __init__(self, config, **kwargs): - super(TFGPT2Embeddings, self).__init__(**kwargs) - self.vocab_size = config.vocab_size - self.hidden_size = config.hidden_size - - def build(self, input_shape): - """Build shared word embedding layer - Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - self.weight = self.add_weight( - "weight", - shape=[self.vocab_size, self.hidden_size], - initializer=tf.random_normal_initializer( - mean=0., stddev=self.hidden_size**-0.5)) - super(TFGPT2Embeddings, self).build(input_shape) - - def call(self, inputs, mode="embedding"): - """Get token embeddings of inputs. - Args: - inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) - mode: string, a valid value is one of "embedding" and "linear". - Returns: - outputs: (1) If mode == "embedding", output embedding tensor, float32 with - shape [batch_size, length, embedding_size]; (2) mode == "linear", output - linear tensor, float32 with shape [batch_size, length, vocab_size]. - Raises: - ValueError: if mode is not valid. - - Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - if mode == "embedding": - return self._embedding(inputs) - elif mode == "linear": - return self._linear(inputs) - else: - raise ValueError("mode {} is not valid.".format(mode)) - - def _embedding(self, input_ids): - """Applies embedding based on inputs tensor.""" - return tf.gather(self.weight, input_ids) - - def _linear(self, inputs): - """Computes logits by running inputs through a linear layer. - Args: - inputs: A float32 tensor with shape [..., hidden_size] - Returns: - float32 tensor with shape [..., vocab_size]. - """ - first_dims = shape_list(inputs)[:-1] - - x = tf.reshape(inputs, [-1, self.hidden_size]) - logits = tf.matmul(x, self.weight, transpose_b=True) - - return tf.reshape(logits, first_dims + [self.vocab_size]) class TFGPT2MainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): @@ -336,10 +279,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): self.vocab_size = config.vocab_size self.n_embd = config.n_embd - self.wte = TFGPT2Embeddings(config, name='wte') + self.wte = TFSharedEmbeddings(config.vocab_size, config.hidden_size, name='wte') self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe') self.drop = tf.keras.layers.Dropout(config.embd_pdrop) - self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)] + self.h = [TFBlock(config.n_ctx, + config, + scale=True, + name='h_{}'.format(i)) for i in range(config.n_layer)] self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f') def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py index 160f97d157..1860ab4f8b 100644 --- a/pytorch_transformers/modeling_tf_utils.py +++ b/pytorch_transformers/modeling_tf_utils.py @@ -288,6 +288,69 @@ class TFConv1D(tf.keras.layers.Layer): return x +class TFSharedEmbeddings(tf.keras.layers.Layer): + """Construct shared token embeddings. + """ + def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs): + super(TFSharedEmbeddings, self).__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + + def build(self, input_shape): + """Build shared word embedding layer + Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + initializer_range = self.hidden_size**-0.5 if self.initializer_range is None else self.initializer_range + self.weight = self.add_weight( + "weight", + shape=[self.vocab_size, self.hidden_size], + initializer=tf.random_normal_initializer( + mean=0., stddev=initializer_range)) + super(TFSharedEmbeddings, self).build(input_shape) + + def call(self, inputs, mode="embedding"): + """Get token embeddings of inputs. + Args: + inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) + mode: string, a valid value is one of "embedding" and "linear". + Returns: + outputs: (1) If mode == "embedding", output embedding tensor, float32 with + shape [batch_size, length, embedding_size]; (2) mode == "linear", output + linear tensor, float32 with shape [batch_size, length, vocab_size]. + Raises: + ValueError: if mode is not valid. + + Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + if mode == "embedding": + return self._embedding(inputs) + elif mode == "linear": + return self._linear(inputs) + else: + raise ValueError("mode {} is not valid.".format(mode)) + + def _embedding(self, input_ids): + """Applies embedding based on inputs tensor.""" + return tf.gather(self.weight, input_ids) + + def _linear(self, inputs): + """Computes logits by running inputs through a linear layer. + Args: + inputs: A float32 tensor with shape [..., hidden_size] + Returns: + float32 tensor with shape [..., vocab_size]. + """ + first_dims = shape_list(inputs)[:-1] + + x = tf.reshape(inputs, [-1, self.hidden_size]) + logits = tf.matmul(x, self.weight, transpose_b=True) + + return tf.reshape(logits, first_dims + [self.vocab_size]) + + class TFSequenceSummary(tf.keras.layers.Layer): r""" Compute a single vector summary of a sequence hidden states according to various possibilities: Args of the config class: diff --git a/pytorch_transformers/modeling_tf_xlnet.py b/pytorch_transformers/modeling_tf_xlnet.py new file mode 100644 index 0000000000..fa24296d76 --- /dev/null +++ b/pytorch_transformers/modeling_tf_xlnet.py @@ -0,0 +1,1121 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF 2.0 XLNet model. +""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import math +import os +import sys +from io import open + +import numpy as np +import tensorflow as tf + +from .configuration_xlnet import XLNetConfig +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .file_utils import add_start_docstrings + + +logger = logging.getLogger(__name__) + +TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = { + 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-tf_model.h5", + 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-tf_model.h5", +} + + +def load_xlnet_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path): + """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format + We use HDF5 to easily do transfer learning + (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). + """ + try: + import re + import torch + import numpy + from tensorflow.python.keras import backend as K + except ImportError: + logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see " + "https://pytorch.org/ for installation instructions.") + raise + + pt_path = os.path.abspath(pytorch_checkpoint_path) + logger.info("Loading PyTorch weights from {}".format(pt_path)) + # Load pytorch model + state_dict = torch.load(pt_path, map_location='cpu') + + inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] + tf_inputs = tf.constant(inputs_list) + tfo = tf_model(tf_inputs, training=False) # build the network + + symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights + weight_value_tuples = [] + all_pytorch_weights = set(list(state_dict.keys())) + for symbolic_weight in symbolic_weights: + name = symbolic_weight.name + name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be + name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers) + name = name.replace(':0', '') + name = name.replace('layer_', 'layer/') + name = name.split('/') + name = name[1:] + + transpose = bool(name[-1] == 'kernel') + if name[-1] == 'kernel' or name[-1] == 'embeddings': + name[-1] = 'weight' + + name = '.'.join(name) + assert name in state_dict, "{} not found in PyTorch model".format(name) + array = state_dict[name].numpy() + + if transpose: + array = numpy.transpose(array) + + try: + assert list(symbolic_weight.shape) == list(array.shape) + except AssertionError as e: + e.args += (symbolic_weight.shape, array.shape) + raise e + + logger.info("Initialize TF weight {}".format(symbolic_weight.name)) + + weight_value_tuples.append((symbolic_weight, array)) + all_pytorch_weights.discard(name) + + K.batch_set_value(weight_value_tuples) + + tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run + + logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights)) + + return tf_model + + +def gelu(x): + """ Implementation of the gelu activation function. + XLNet is using OpenAI GPT's gelu + Also see https://arxiv.org/abs/1606.08415 + """ + cdf = 0.5 * (1.0 + tf.tanh( + (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) + return x * cdf + + +def swish(x): + return x * tf.sigmoid(x) + + +ACT2FN = {"gelu": tf.keras.layers.Activation(gelu), + "relu": tf.keras.activations.relu, + "swish": tf.keras.layers.Activation(swish)} + + +class TFXLNetRelativeAttention(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFXLNetRelativeAttention, self).__init__(**kwargs) + self.output_attentions = config.output_attentions + + if config.d_model % config.n_head != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.d_model, config.n_head)) + + self.n_head = config.n_head + self.d_head = config.d_head + self.d_model = config.d_model + self.scale = 1 / (config.d_head ** 0.5) + self.initializer_range = config.initializer_range + + self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm') + self.dropout = tf.keras.layers.Dropout(config.dropout) + + def build(input_shape): + initializer = tf.random_normal_initializer(mean=0., stddev=self.initializer_range) + self.q = self.add_weight(shape=(self.d_model, self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='q') + self.k = self.add_weight(shape=(self.d_model, self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='k') + self.v = self.add_weight(shape=(self.d_model, self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='v') + self.o = self.add_weight(shape=(self.d_model, self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='o') + self.r = self.add_weight(shape=(self.d_model, self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='r') + self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='r_r_bias') + self.r_s_bias = self.add_weight(shape=(self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='r_s_bias') + self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='r_w_bias') + self.seg_embed = self.add_weight(shape=(2, self.n_head, self.d_head), + initializer=initializer, + trainable=True, name='seg_embed') + super(TFXLNetRelativeAttention, self).build(input_shape) + + def prune_heads(self, heads): + raise NotImplementedError + + @staticmethod + def rel_shift(x, klen=-1): + """perform relative shift to form the relative attention score.""" + x_size = shape_list(x) + + x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3])) + x = x[1:, ...] + x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3])) + x = x[:, 0:klen, :, :] + # x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long)) + + return x + + def rel_attn_core(self, inputs, training=False): + """Core relative positional attention operations.""" + + q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask = inputs + + # content based attention score + ac = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_w_bias, k_head_h) + + # position based attention score + bd = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r) + bd = self.rel_shift(bd, klen=ac.shape[1]) + + # segment based attention score + if seg_mat is None: + ef = 0 + else: + ef = tf.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed) + ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef) + + # merge attention scores and perform masking + attn_score = (ac + bd + ef) * self.scale + if attn_mask is not None: + # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask + if attn_mask.dtype == tf.float16: + attn_score = attn_score - 65500 * attn_mask + else: + attn_score = attn_score - 1e30 * attn_mask + + # attention probability + attn_prob = tf.softmax(attn_score, axis=1) + + if training: + attn_prob = self.dropout(attn_prob) + + # Mask heads if we want to + if head_mask is not None: + attn_prob = attn_prob * head_mask + + # attention output + attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h) + + if self.output_attentions: + return attn_vec, attn_prob + + return attn_vec + + def post_attention(self, inputs, training=False): + """Post-attention processing.""" + # post-attention projection (back to `d_model`) + h, attn_vec, residual = inputs + + attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, self.o) + + if training: + attn_out = self.dropout(attn_out) + + if residual: + attn_out = attn_out + h + output = self.layer_norm(attn_out) + + return output + + def call(self, inputs, training=False): + (h, g, attn_mask_h, attn_mask_g, + r, seg_mat, mems, target_mapping, head_mask) = inputs + + if g is not None: + ###### Two-stream attention with relative positional encoding. + # content based attention score + if mems is not None and mems.dim() > 1: + cat = torch.cat([mems, h], dim=0) + else: + cat = h + + # content-based key head + k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k) + + # content-based value head + v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v) + + # position-based key head + k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r) + + ##### h-stream + # content-stream query head + q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q) + + # core attention ops + attn_vec_h = self.rel_attn_core( + [q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], + training=training) + + if self.output_attentions: + attn_vec_h, attn_prob_h = attn_vec_h + + # post processing + output_h = self.post_attention([h, attn_vec_h], training=training) + + ##### g-stream + # query-stream query head + q_head_g = torch.einsum('ibh,hnd->ibnd', g, self.q) + + # core attention ops + if target_mapping is not None: + q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping) + attn_vec_g = self.rel_attn_core( + [q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], + training=training) + + if self.output_attentions: + attn_vec_g, attn_prob_g = attn_vec_g + + attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping) + else: + attn_vec_g = self.rel_attn_core( + [q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], + training=training) + + if self.output_attentions: + attn_vec_g, attn_prob_g = attn_vec_g + + # post processing + output_g = self.post_attention([g, attn_vec_g], training=training) + + if self.output_attentions: + attn_prob = attn_prob_h, attn_prob_g + + else: + ###### Multi-head attention with relative positional encoding + if mems is not None and mems.dim() > 1: + cat = tf.concat([mems, h], dim=0) + else: + cat = h + + # content heads + q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.q) + k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.k) + v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.v) + + # positional heads + k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.r) + + # core attention ops + attn_vec = self.rel_attn_core( + [q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], + training=training) + + if self.output_attentions: + attn_vec, attn_prob = attn_vec + + # post processing + output_h = self.post_attention([h, attn_vec], training=training) + output_g = None + + outputs = (output_h, output_g) + if self.output_attentions: + outputs = outputs + (attn_prob,) + return outputs + +class TFXLNetFeedForward(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFXLNetFeedForward, self).__init__(**kwargs) + self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm') + self.layer_1 = tf.keras.layers.Dense(config.d_inner, name='layer_1') + self.layer_2 = tf.keras.layers.Dense(config.d_model, name='layer_2') + self.dropout = tf.keras.layers.Dropout(config.dropout) + if isinstance(config.ff_activation, str) or \ + (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)): + self.activation_function = ACT2FN[config.ff_activation] + else: + self.activation_function = config.ff_activation + + def call(self, inp, training=False): + output = inp + output = self.layer_1(output) + output = self.activation_function(output) + if training: + output = self.dropout(output) + output = self.layer_2(output) + if training: + output = self.dropout(output) + output = self.layer_norm(output + inp) + return output + +class TFXLNetLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFXLNetLayer, self).__init__(**kwargs) + self.rel_attn = TFXLNetRelativeAttention(config, name='rel_attn') + self.ff = TFXLNetFeedForward(config, name='ff') + self.dropout = tf.keras.layers.Dropout(config.dropout) + + def call(self, inputs, training=False): + outputs = self.rel_attn(inputs, training=training) + output_h, output_g = outputs[:2] + + if output_g is not None: + output_g = self.ff(output_g, training=training) + output_h = self.ff(output_h, training=training) + + outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there + return outputs + + +class TFXLNetMainLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFXLNetMainLayer, self).__init__(**kwargs) + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + + self.mem_len = config.mem_len + self.reuse_len = config.reuse_len + self.d_model = config.d_model + self.same_length = config.same_length + self.attn_type = config.attn_type + self.bi_data = config.bi_data + self.clamp_len = config.clamp_len + self.n_layer = config.n_layer + self.use_bfloat16 = config.use_bfloat16 + self.initializer_range = config.initializer_range + + self.word_embedding = TFSharedEmbeddings(config.n_token, config.d_model, initializer_range=config.initializer_range, name='word_embedding') + self.layer = [XLNetLayer(config, name='layer_{}'.format(i)) for i in range(config.n_layer)] + self.dropout = tf.keras.layers.Dropout(config.dropout) + + def build(input_shape): + initializer = tf.random_normal_initializer(mean=0., stddev=self.initializer_range) + self.mask_emb = self.add_weight(shape=(1, 1, config.d_model), + initializer=initializer, + trainable=True, name='mask_emb') + + def _resize_token_embeddings(self, new_num_tokens): + raise NotImplementedError + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + def create_mask(self, qlen, mlen, dtype=tf.float32): + """ + Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked. + + Args: + qlen: TODO Lysandre didn't fill + mlen: TODO Lysandre didn't fill + + :: + + same_length=False: same_length=True: + < qlen > < qlen > + ^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1] + [0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1] + qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1] + [0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1] + v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0] + + """ + attn_mask = tf.ones([qlen, qlen], dtype=dtype) + mask_u = tf.matrix_band_part(attn_mask, 0, -1) + mask_dia = tf.matrix_band_part(attn_mask, 0, 0) + attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype) + ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) + if self.same_length: + mask_l = tf.matrix_band_part(attn_mask, -1, 0) + ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) + return ret + + def cache_mem(self, curr_out, prev_mem): + """cache hidden states into memory.""" + if self.mem_len is None or self.mem_len == 0: + return None + else: + if self.reuse_len is not None and self.reuse_len > 0: + curr_out = curr_out[:self.reuse_len] + + if prev_mem is None: + new_mem = curr_out[-self.mem_len:] + else: + new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:] + + return tf.stop_gradient(new_mem) + + @staticmethod + def positional_embedding(pos_seq, inv_freq, bsz=None): + sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq) + pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1) + pos_emb = pos_emb[:, None, :] + + if bsz is not None: + pos_emb = tf.tile(pos_emb, [1, bsz, 1]) + + return pos_emb + + def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None): + """create relative positional encoding.""" + freq_seq = tf.range(0, self.d_model, 2.0) + if dtype is not None and dtype != tf.float32: + freq_seq = tf.cast(freq_seq, dtype=dtype) + inv_freq = 1 / (10000 ** (freq_seq / self.d_model)) + + if self.attn_type == 'bi': + # beg, end = klen - 1, -qlen + beg, end = klen, -qlen + elif self.attn_type == 'uni': + # beg, end = klen - 1, -1 + beg, end = klen, -1 + else: + raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type)) + + if self.bi_data: + fwd_pos_seq = tf.range(beg, end, -1.0) + bwd_pos_seq = tf.range(-beg, -end, 1.0) + + if dtype is not None and dtype != tf.float32: + fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) + bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype) + + if self.clamp_len > 0: + fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len) + bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len) + + if bsz is not None: + # With bi_data, the batch size should be divisible by 2. + assert bsz%2 == 0 + fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2) + bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2) + else: + fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq) + bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq) + + pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1) + else: + fwd_pos_seq = tf.range(beg, end, -1.0) + if dtype is not None and dtype != tf.float32: + fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) + if self.clamp_len > 0: + fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len) + pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) + + return pos_emb + + def call(self, inputs, training=False): + (input_ids, attention_mask, mems, perm_mask, target_mapping, + token_type_ids, input_mask, head_mask) = inputs + # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end + # but we want a unified interface in the library with the batch size on the first dimension + # so we move here the first dimension (batch) to the end + + input_ids = tf.transpose(input_ids, perm=(0, 1)) + token_type_ids = tf.transpose(token_type_ids, perm=(0, 1)) if token_type_ids is not None else None + input_mask = tf.transpose(input_mask, perm=(0, 1)) if input_mask is not None else None + attention_mask = tf.transpose(attention_mask, perm=(0, 1)) if attention_mask is not None else None + perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None + target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None + + qlen, bsz = shape_list(input_ids)[:2] + mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0 + klen = mlen + qlen + + dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32 + + ##### Attention mask + # causal attention mask + if self.attn_type == 'uni': + attn_mask = self.create_mask(qlen, mlen) + attn_mask = attn_mask[:, :, None, None] + elif self.attn_type == 'bi': + attn_mask = None + else: + raise ValueError('Unsupported attention type: {}'.format(self.attn_type)) + + # data mask: input mask & perm mask + assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " + "or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one." + if input_mask is None and attention_mask is not None: + input_mask = 1.0 - attention_mask + if input_mask is not None and perm_mask is not None: + data_mask = input_mask[None] + perm_mask + elif input_mask is not None and perm_mask is None: + data_mask = input_mask[None] + elif input_mask is None and perm_mask is not None: + data_mask = perm_mask + else: + data_mask = None + + if data_mask is not None: + # all mems can be attended to + mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], + dtype=dtype_float) + data_mask = tf.concat([mems_mask, data_mask], axis=1) + if attn_mask is None: + attn_mask = data_mask[:, :, :, None] + else: + attn_mask += data_mask[:, :, :, None] + + if attn_mask is not None: + attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float) + + if attn_mask is not None: + non_tgt_mask = -tf.eye(qlen, dtype=dtype_float) + non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1) + non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float) + else: + non_tgt_mask = None + + ##### Word embeddings and prepare h & g hidden states + word_emb_k = self.word_embedding(input_ids) + if training: + output_h = self.dropout(word_emb_k) + if target_mapping is not None: + word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1]) + # else: # We removed the inp_q input which was same as target mapping + # inp_q_ext = inp_q[:, :, None] + # word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k + if training: + output_g = self.dropout(word_emb_q) + else: + output_g = None + + ##### Segment embedding + if token_type_ids is not None: + # Convert `token_type_ids` to one-hot `seg_mat` + mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) + cat_ids = tf.concat([mem_pad, token_type_ids], 0) + + # `1` indicates not in the same segment [qlen x klen x bsz] + seg_mat = tf.cast( + tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), + tf.int32) + seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float) + else: + seg_mat = None + + ##### Positional encoding + pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float) + if training: + pos_emb = self.dropout(pos_emb) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) + # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) + head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.n_layer + + new_mems = () + if mems is None: + mems = [None] * len(self.layer) + + attentions = [] + hidden_states = [] + for i, layer_module in enumerate(self.layer): + # cache new mems + new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) + if self.output_hidden_states: + hidden_states.append((output_h, output_g) if output_g is not None else output_h) + + outputs = layer_module([output_h, output_g, non_tgt_mask, attn_mask, + pos_emb, seg_mat, mems[i], target_mapping, + head_mask[i]], training=training) + output_h, output_g = outputs[:2] + if self.output_attentions: + attentions.append(outputs[2]) + + # Add last hidden state + if self.output_hidden_states: + hidden_states.append((output_h, output_g) if output_g is not None else output_h) + + if training: + output = self.dropout(output_g if output_g is not None else output_h) + + # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) + outputs = (tf.transpose(output, perm=(1, 0, 2)), new_mems) + if self.output_hidden_states: + if output_g is not None: + hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) + else: + hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states) + outputs = outputs + (hidden_states,) + if self.output_attentions: + attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) + outputs = outputs + (attentions,) + + return outputs # outputs, new_mems, (hidden_states), (attentions) + + +class TFXLNetPreTrainedModel(TFPreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + config_class = XLNetConfig + pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP + load_pt_weights = load_xlnet_pt_weights_in_tf2 + base_model_prefix = "transformer" + + +XLNET_START_DOCSTRING = r""" The XLNet model was proposed in + `XLNet: Generalized Autoregressive Pretraining for Language Understanding`_ + by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. + XLnet is an extension of the Transformer-XL model pre-trained using an autoregressive method + to learn bidirectional contexts by maximizing the expected likelihood over all permutations + of the input sequence factorization order. + + The specific attention pattern can be controlled at training and test time using the `perm_mask` input. + + Do to the difficulty of training a fully auto-regressive model over various factorization order, + XLNet is pretrained using only a sub-set of the output tokens as target which are selected + with the `target_mapping` input. + + To use XLNet for sequential decoding (i.e. not in fully bi-directional setting), use the `perm_mask` and + `target_mapping` inputs to control the attention span and outputs (see examples in `examples/run_generation.py`) + + This model is a PyTorch `torch.tf.keras.layers.Layer`_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matter related to general usage and behavior. + + .. _`XLNet: Generalized Autoregressive Pretraining for Language Understanding`: + http://arxiv.org/abs/1906.08237 + + .. _`torch.tf.keras.layers.Layer`: + https://pytorch.org/docs/stable/nn.html#module + + Parameters: + config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +XLNET_INPUTS_DOCSTRING = r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + XLNet is a model with relative position embeddings so you can either pad the inputs on + the right or on the left. + Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`. + See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and + :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **mems**: (`optional`) + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as output by the model + (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context. + To activate mems you need to set up config.mem_len to a positive value which will be the max number of tokens in + the memory output by the model. E.g. `model = XLNetModel.from_pretrained('xlnet-base-case, mem_len=1024)` will + instantiate a model which can use up to 1024 tokens of memory (in addition to the input it self). + **perm_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, sequence_length)``: + Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``: + If ``perm_mask[k, i, j] = 0``, i attend to j in batch k; + if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k. + If None, each token attends to all the others (full bidirectional attention). + Only used during pretraining (to define factorization order) or for sequential decoding (generation). + **target_mapping**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_predict, sequence_length)``: + Mask to indicate the output tokens to use. + If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token. + Only used during pretraining for partial prediction or for sequential decoding (generation). + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + A parallel sequence of tokens (can be used to indicate various portions of the inputs). + The type indices in XLNet are NOT selected in the vocabulary, they can be arbitrary numbers and + the important thing is that they should be different for tokens which belong to different segments. + The model will compute relative segment differences from the given type indices: + 0 if the segment id of two tokens are the same, 1 if not. + **input_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. + Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding. + Kept for compatibility with the original code base. + You can only uses one of `input_mask` and `attention_mask` + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. +""" + +@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.", + XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) +class TFXLNetModel(TFXLNetPreTrainedModel): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the last layer of the model. + **mems**: + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. + See details in the docstring of the `mems` input above. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') + model = XLNetModel.from_pretrained('xlnet-large-cased') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + def __init__(self, config, *inputs, **kwargs): + super(TFXLNetModel, self).__init__(config, *inputs, **kwargs) + self.transformer = TFBertMainLayer(config, name='transformer') + + def call(self, inputs, training=False): + outputs = self.transformer(inputs, training=training) + return outputs + + +# @add_start_docstrings("""XLNet Model with a language modeling head on top +# (linear layer with weights tied to the input embeddings). """, +# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) +# class XLNetLMHeadModel(XLNetPreTrainedModel): +# r""" +# **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: +# Labels for language modeling. +# Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` +# Indices are selected in ``[-1, 0, ..., config.vocab_size]`` +# All labels set to ``-1`` are ignored (masked), the loss is only +# computed for labels in ``[0, ..., config.vocab_size]`` + +# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: +# **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: +# Language modeling loss. +# **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` +# Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). +# **mems**: +# list of ``torch.FloatTensor`` (one for each layer): +# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model +# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. +# See details in the docstring of the `mems` input above. +# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) +# list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) +# of shape ``(batch_size, sequence_length, hidden_size)``: +# Hidden-states of the model at the output of each layer plus the initial embedding outputs. +# **attentions**: (`optional`, returned when ``config.output_attentions=True``) +# list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: +# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + +# Examples:: + +# tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') +# model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased') +# # We show how to setup inputs to predict a next token using a bi-directional context. +# input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very ")).unsqueeze(0) # We will predict the masked token +# perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float) +# perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token +# target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token +# target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token) +# outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping) +# next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size] + +# """ +# def __init__(self, config, **kwargs): +# super(XLNetLMHeadModel, self).__init__(config) +# self.attn_type = config.attn_type +# self.same_length = config.same_length + +# self.transformer = XLNetModel(config) +# self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) + +# self.init_weights() +# self.tie_weights() + +# def tie_weights(self): +# """ Make sure we are sharing the embeddings +# """ +# self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding) + +# def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, +# token_type_ids=None, input_mask=None, head_mask=None, labels=None): +# transformer_outputs = self.transformer(input_ids, +# attention_mask=attention_mask, +# mems=mems, +# perm_mask=perm_mask, +# target_mapping=target_mapping, +# token_type_ids=token_type_ids, +# input_mask=input_mask, +# head_mask=head_mask) + +# logits = self.lm_loss(transformer_outputs[0]) + +# outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it + +# if labels is not None: +# # Flatten the tokens +# loss_fct = CrossEntropyLoss(ignore_index=-1) +# loss = loss_fct(logits.view(-1, logits.size(-1)), +# labels.view(-1)) +# outputs = (loss,) + outputs + +# return outputs # return (loss), logits, mems, (hidden states), (attentions) + + +# @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of +# the pooled output) e.g. for GLUE tasks. """, +# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) +# class XLNetForSequenceClassification(XLNetPreTrainedModel): +# r""" +# **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: +# Labels for computing the sequence classification/regression loss. +# Indices should be in ``[0, ..., config.num_labels - 1]``. +# If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), +# If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). + +# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: +# **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: +# Classification (or regression if config.num_labels==1) loss. +# **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` +# Classification (or regression if config.num_labels==1) scores (before SoftMax). +# **mems**: +# list of ``torch.FloatTensor`` (one for each layer): +# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model +# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. +# See details in the docstring of the `mems` input above. +# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) +# list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) +# of shape ``(batch_size, sequence_length, hidden_size)``: +# Hidden-states of the model at the output of each layer plus the initial embedding outputs. +# **attentions**: (`optional`, returned when ``config.output_attentions=True``) +# list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: +# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + +# Examples:: + +# tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') +# model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased') +# input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 +# labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 +# outputs = model(input_ids, labels=labels) +# loss, logits = outputs[:2] + +# """ +# def __init__(self, config, **kwargs): +# super(XLNetForSequenceClassification, self).__init__(config) +# self.num_labels = config.num_labels + +# self.transformer = XLNetModel(config) +# self.sequence_summary = SequenceSummary(config) +# self.logits_proj = nn.Linear(config.d_model, config.num_labels) + +# self.init_weights() + +# def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, +# token_type_ids=None, input_mask=None, head_mask=None, labels=None): +# transformer_outputs = self.transformer(input_ids, +# attention_mask=attention_mask, +# mems=mems, +# perm_mask=perm_mask, +# target_mapping=target_mapping, +# token_type_ids=token_type_ids, +# input_mask=input_mask, +# head_mask=head_mask) +# output = transformer_outputs[0] + +# output = self.sequence_summary(output) +# logits = self.logits_proj(output) + +# outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it + +# if labels is not None: +# if self.num_labels == 1: +# # We are doing regression +# loss_fct = MSELoss() +# loss = loss_fct(logits.view(-1), labels.view(-1)) +# else: +# loss_fct = CrossEntropyLoss() +# loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) +# outputs = (loss,) + outputs + +# return outputs # return (loss), logits, mems, (hidden states), (attentions) + + +# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of +# the hidden-states output to compute `span start logits` and `span end logits`). """, +# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) +# class XLNetForQuestionAnswering(XLNetPreTrainedModel): +# r""" +# **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: +# Labels for position (index) of the start of the labelled span for computing the token classification loss. +# Positions are clamped to the length of the sequence (`sequence_length`). +# Position outside of the sequence are not taken into account for computing the loss. +# **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: +# Labels for position (index) of the end of the labelled span for computing the token classification loss. +# Positions are clamped to the length of the sequence (`sequence_length`). +# Position outside of the sequence are not taken into account for computing the loss. +# **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: +# Labels whether a question has an answer or no answer (SQuAD 2.0) +# **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: +# Labels for position (index) of the classification token to use as input for computing plausibility of the answer. +# **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: +# Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). +# 1.0 means token should be masked. 0.0 mean token is not masked. + +# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: +# **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: +# Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. +# **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) +# ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` +# Log probabilities for the top config.start_n_top start token possibilities (beam-search). +# **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) +# ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` +# Indices for the top config.start_n_top start token possibilities (beam-search). +# **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) +# ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` +# Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). +# **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) +# ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` +# Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). +# **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) +# ``torch.FloatTensor`` of shape ``(batch_size,)`` +# Log probabilities for the ``is_impossible`` label of the answers. +# **mems**: +# list of ``torch.FloatTensor`` (one for each layer): +# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model +# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. +# See details in the docstring of the `mems` input above. +# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) +# list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) +# of shape ``(batch_size, sequence_length, hidden_size)``: +# Hidden-states of the model at the output of each layer plus the initial embedding outputs. +# **attentions**: (`optional`, returned when ``config.output_attentions=True``) +# list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: +# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + +# Examples:: + +# tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048') +# model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased') +# input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 +# start_positions = torch.tensor([1]) +# end_positions = torch.tensor([3]) +# outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) +# loss, start_scores, end_scores = outputs[:2] + +# """ +# def __init__(self, config, **kwargs): +# super(XLNetForQuestionAnswering, self).__init__(config) +# self.start_n_top = config.start_n_top +# self.end_n_top = config.end_n_top + +# self.transformer = XLNetModel(config) +# self.start_logits = PoolerStartLogits(config) +# self.end_logits = PoolerEndLogits(config) +# self.answer_class = PoolerAnswerClass(config) + +# self.init_weights() + +# def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, +# token_type_ids=None, input_mask=None, head_mask=None, +# start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,): +# transformer_outputs = self.transformer(input_ids, +# attention_mask=attention_mask, +# mems=mems, +# perm_mask=perm_mask, +# target_mapping=target_mapping, +# token_type_ids=token_type_ids, +# input_mask=input_mask, +# head_mask=head_mask) +# hidden_states = transformer_outputs[0] +# start_logits = self.start_logits(hidden_states, p_mask=p_mask) + +# outputs = transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it + +# if start_positions is not None and end_positions is not None: +# # If we are on multi-GPU, let's remove the dimension added by batch splitting +# for x in (start_positions, end_positions, cls_index, is_impossible): +# if x is not None and x.dim() > 1: +# x.squeeze_(-1) + +# # during training, compute the end logits based on the ground truth of the start position +# end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + +# loss_fct = CrossEntropyLoss() +# start_loss = loss_fct(start_logits, start_positions) +# end_loss = loss_fct(end_logits, end_positions) +# total_loss = (start_loss + end_loss) / 2 + +# if cls_index is not None and is_impossible is not None: +# # Predict answerability from the representation of CLS and START +# cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) +# loss_fct_cls = nn.BCEWithLogitsLoss() +# cls_loss = loss_fct_cls(cls_logits, is_impossible) + +# # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss +# total_loss += cls_loss * 0.5 + +# outputs = (total_loss,) + outputs + +# else: +# # during inference, compute the end logits based on beam search +# bsz, slen, hsz = hidden_states.size() +# start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) + +# start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) +# start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) +# start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) +# start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + +# hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) +# p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None +# end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) +# end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + +# end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) +# end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) +# end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + +# start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states +# cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample + +# outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs + +# # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits +# # or (if labels are provided) (total_loss,) +# return outputs diff --git a/pytorch_transformers/tests/modeling_tf_common_test.py b/pytorch_transformers/tests/modeling_tf_common_test.py index da3263ffde..3dae24b283 100644 --- a/pytorch_transformers/tests/modeling_tf_common_test.py +++ b/pytorch_transformers/tests/modeling_tf_common_test.py @@ -262,7 +262,7 @@ class TFCommonTestCases: # self.assertEqual(len(params_tied_2), len(params_tied)) -def ids_tensor(shape, vocab_size, rng=None, name=None): +def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32): """Creates a random int32 tensor of the shape within the vocab size.""" if rng is None: rng = random.Random() @@ -275,7 +275,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None): for _ in range(total_dims): values.append(rng.randint(0, vocab_size - 1)) - return tf.constant(values, shape=shape) + return tf.constant(values, shape=shape, dtype=dtype) class TFModelUtilsTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/modeling_tf_xlnet_test.py b/pytorch_transformers/tests/modeling_tf_xlnet_test.py new file mode 100644 index 0000000000..58e53e1311 --- /dev/null +++ b/pytorch_transformers/tests/modeling_tf_xlnet_test.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +import json +import random +import shutil +import pytest + +from pytorch_transformers import XLNetConfig, is_tf_available + +if is_tf_available(): + import tensorflow as tf + + from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) + # XLNetLMHeadModel, + # XLNetForSequenceClassification, XLNetForQuestionAnswering) +else: + pytestmark = pytest.mark.skip("Require TensorFlow") + +from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester + +class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): + + all_model_classes=(TFXLNetModel, ) if is_tf_available() else () + # all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel, + # TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else () + test_pruning = False + + class TFXLNetModelTester(object): + + def __init__(self, + parent, + batch_size=13, + seq_length=7, + mem_len=10, + clamp_len=-1, + reuse_len=15, + is_training=True, + use_labels=True, + vocab_size=99, + cutoffs=[10, 50, 80], + hidden_size=32, + num_attention_heads=4, + d_inner=128, + num_hidden_layers=5, + max_position_embeddings=10, + type_sequence_label_size=2, + untie_r=True, + bi_data=False, + same_length=False, + initializer_range=0.05, + seed=1, + type_vocab_size=2, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.mem_len = mem_len + # self.key_len = seq_length + mem_len + self.clamp_len = clamp_len + self.reuse_len = reuse_len + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.cutoffs = cutoffs + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.d_inner = d_inner + self.num_hidden_layers = num_hidden_layers + self.max_position_embeddings = max_position_embeddings + self.bi_data = bi_data + self.untie_r = untie_r + self.same_length = same_length + self.initializer_range = initializer_range + self.seed = seed + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + + def prepare_config_and_inputs(self): + input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32) + + input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) + perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32) + perm_mask_last = tf.ones((self.batch_size, self.seq_length + 1, 1), dtype=tf.float32) + perm_mask = tf.concat([perm_mask, perm_mask_last], axis=-1) + # perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token + target_mapping = tf.zeros((self.batch_size, 1, self.seq_length), dtype=torch.float32) + target_mapping_last = tf.ones((self.batch_size, 1, 1), dtype=torch.float32) + target_mapping = tf.concat([target_mapping, target_mapping_last], axis=-1) + # target_mapping[:, 0, -1] = 1.0 # predict last token + + sequence_labels = None + lm_labels = None + is_impossible_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32) + + config = XLNetConfig( + vocab_size_or_config_json_file=self.vocab_size, + d_model=self.hidden_size, + n_head=self.num_attention_heads, + d_inner=self.d_inner, + n_layer=self.num_hidden_layers, + untie_r=self.untie_r, + max_position_embeddings=self.max_position_embeddings, + mem_len=self.mem_len, + clamp_len=self.clamp_len, + same_length=self.same_length, + reuse_len=self.reuse_len, + bi_data=self.bi_data, + initializer_range=self.initializer_range, + num_labels=self.type_sequence_label_size) + + return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels) + + def set_seed(self): + random.seed(self.seed) + tf.random.set_seed(self.seed) + + def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + model = TFXLNetModel(config) + + inputs = {'input_ids': input_ids, + 'input_mask': input_mask, + 'token_type_ids': token_type_ids} + + _, _ = model(inputs) + + inputs = [input_ids, input_mask] + + outputs, mems_1 = model(inputs) + + result = { + "mems_1": [mem.numpy() for m in mems_1], + "outputs": outputs.numpy(), + } + + self.parent.assertListEqual( + list(result["outputs"].shape), + [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertListEqual( + list(list(mem.shape) for mem in result["mems_1"]), + [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + + def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + pass + # model = XLNetLMHeadModel(config) + # model.eval() + + # loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels) + + # loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1) + + # logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping) + + # result = { + # "loss_1": loss_1, + # "mems_1": mems_1, + # "all_logits_1": all_logits_1, + # "loss_2": loss_2, + # "mems_2": mems_2, + # "all_logits_2": all_logits_2, + # } + + # self.parent.assertListEqual( + # list(result["loss_1"].size()), + # []) + # self.parent.assertListEqual( + # list(result["all_logits_1"].size()), + # [self.batch_size, self.seq_length, self.vocab_size]) + # self.parent.assertListEqual( + # list(list(mem.size()) for mem in result["mems_1"]), + # [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + + # self.parent.assertListEqual( + # list(result["loss_2"].size()), + # []) + # self.parent.assertListEqual( + # list(result["all_logits_2"].size()), + # [self.batch_size, self.seq_length, self.vocab_size]) + # self.parent.assertListEqual( + # list(list(mem.size()) for mem in result["mems_2"]), + # [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + + def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + pass + # model = XLNetForQuestionAnswering(config) + # model.eval() + + # outputs = model(input_ids_1) + # start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs + + # outputs = model(input_ids_1, start_positions=sequence_labels, + # end_positions=sequence_labels, + # cls_index=sequence_labels, + # is_impossible=is_impossible_labels, + # p_mask=input_mask) + + # outputs = model(input_ids_1, start_positions=sequence_labels, + # end_positions=sequence_labels, + # cls_index=sequence_labels, + # is_impossible=is_impossible_labels) + + # total_loss, mems = outputs + + # outputs = model(input_ids_1, start_positions=sequence_labels, + # end_positions=sequence_labels) + + # total_loss, mems = outputs + + # result = { + # "loss": total_loss, + # "start_top_log_probs": start_top_log_probs, + # "start_top_index": start_top_index, + # "end_top_log_probs": end_top_log_probs, + # "end_top_index": end_top_index, + # "cls_logits": cls_logits, + # "mems": mems, + # } + + # self.parent.assertListEqual( + # list(result["loss"].size()), + # []) + # self.parent.assertListEqual( + # list(result["start_top_log_probs"].size()), + # [self.batch_size, model.config.start_n_top]) + # self.parent.assertListEqual( + # list(result["start_top_index"].size()), + # [self.batch_size, model.config.start_n_top]) + # self.parent.assertListEqual( + # list(result["end_top_log_probs"].size()), + # [self.batch_size, model.config.start_n_top * model.config.end_n_top]) + # self.parent.assertListEqual( + # list(result["end_top_index"].size()), + # [self.batch_size, model.config.start_n_top * model.config.end_n_top]) + # self.parent.assertListEqual( + # list(result["cls_logits"].size()), + # [self.batch_size]) + # self.parent.assertListEqual( + # list(list(mem.size()) for mem in result["mems"]), + # [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + + def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + pass + # model = XLNetForSequenceClassification(config) + # model.eval() + + # logits, mems_1 = model(input_ids_1) + # loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels) + + # result = { + # "loss": loss, + # "mems_1": mems_1, + # "logits": logits, + # } + + # self.parent.assertListEqual( + # list(result["loss"].size()), + # []) + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.type_sequence_label_size]) + # self.parent.assertListEqual( + # list(list(mem.size()) for mem in result["mems_1"]), + # [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, + sequence_labels, is_impossible_labels) = config_and_inputs + inputs_dict = {'input_ids': input_ids_1} + return config, inputs_dict + + + def setUp(self): + self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_xlnet_base_model(self): + self.model_tester.set_seed() + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs) + + def test_xlnet_lm_head(self): + self.model_tester.set_seed() + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs) + + def test_xlnet_sequence_classif(self): + self.model_tester.set_seed() + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs) + + def test_xlnet_qa(self): + self.model_tester.set_seed() + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_qa(*config_and_inputs) + + @pytest.mark.slow + def test_model_from_pretrained(self): + cache_dir = "/tmp/pytorch_transformers_test/" + for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir) + shutil.rmtree(cache_dir) + self.assertIsNotNone(model) + + +if __name__ == "__main__": + unittest.main()