From 0f9fc4fbde4ea95f817aaf710fecb1b898c61088 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 11 Oct 2019 15:47:08 +0200 Subject: [PATCH] adding option to desactivate past/memory outputs --- transformers/configuration_utils.py | 3 +- transformers/modeling_ctrl.py | 15 +++--- transformers/modeling_gpt2.py | 10 ++-- transformers/modeling_tf_ctrl.py | 12 +++-- transformers/modeling_tf_xlnet.py | 48 ++++++++++-------- transformers/modeling_xlnet.py | 52 +++++++++++--------- transformers/tests/modeling_tf_xlnet_test.py | 4 ++ transformers/tests/modeling_xlnet_test.py | 4 ++ 8 files changed, 93 insertions(+), 55 deletions(-) diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 9f79b85ef8..cfa6502bcd 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -53,7 +53,8 @@ class PretrainedConfig(object): self.num_labels = kwargs.pop('num_labels', 2) self.output_attentions = kwargs.pop('output_attentions', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False) - self.torchscript = kwargs.pop('torchscript', False) + self.output_past = kwargs.pop('output_past', True) # Not used by all models + self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.pruned_heads = kwargs.pop('pruned_heads', {}) diff --git a/transformers/modeling_ctrl.py b/transformers/modeling_ctrl.py index 9857a7ef19..55e64d318b 100644 --- a/transformers/modeling_ctrl.py +++ b/transformers/modeling_ctrl.py @@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel): def __init__(self, config): super(CTRLModel, self).__init__(config) self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.output_past = config.output_past + self.d_model_size = config.n_embd self.num_layers = config.n_layer - + self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float) - - self.output_attentions = config.output_attentions self.w = nn.Embedding(config.vocab_size, config.n_embd) - self.dropout = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, @@ -378,7 +378,8 @@ class CTRLModel(CTRLPreTrainedModel): attention_mask=attention_mask, head_mask=head_mask[i]) hidden_states, present = outputs[:2] - presents = presents + (present,) + if self.output_past: + presents = presents + (present,) if self.output_attentions: all_attentions.append(outputs[2]) @@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - outputs = (hidden_states, presents) + outputs = (hidden_states,) + if self.output_past: + outputs = outputs + (presents,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) if self.output_attentions: diff --git a/transformers/modeling_gpt2.py b/transformers/modeling_gpt2.py index 891dfc5677..0b5b83aa75 100644 --- a/transformers/modeling_gpt2.py +++ b/transformers/modeling_gpt2.py @@ -347,6 +347,7 @@ class GPT2Model(GPT2PreTrainedModel): super(GPT2Model, self).__init__(config) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions + self.output_past = config.output_past self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd) @@ -440,7 +441,8 @@ class GPT2Model(GPT2PreTrainedModel): head_mask=head_mask[i]) hidden_states, present = outputs[:2] - presents = presents + (present,) + if self.output_past: + presents = presents + (present,) if self.output_attentions: all_attentions.append(outputs[2]) @@ -452,7 +454,9 @@ class GPT2Model(GPT2PreTrainedModel): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - outputs = (hidden_states, presents) + outputs = (hidden_states,) + if self.output_past: + outputs = outputs + (presents,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) if self.output_attentions: @@ -460,7 +464,7 @@ class GPT2Model(GPT2PreTrainedModel): attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) outputs = outputs + (all_attentions,) - return outputs # last hidden state, presents, (all hidden_states), (attentions) + return outputs # last hidden state, (presents), (all hidden_states), (attentions) @add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top diff --git a/transformers/modeling_tf_ctrl.py b/transformers/modeling_tf_ctrl.py index 95cc873448..c8d181548b 100644 --- a/transformers/modeling_tf_ctrl.py +++ b/transformers/modeling_tf_ctrl.py @@ -168,12 +168,14 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super(TFCTRLMainLayer, self).__init__(**kwargs) self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.output_past = config.output_past + self.d_model_size = config.n_embd self.num_layers = config.n_layer self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size) - self.output_attentions = config.output_attentions self.w = TFSharedEmbeddings(config.vocab_size, config.n_embd, @@ -290,7 +292,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training) hidden_states, present = outputs[:2] - presents = presents + (present,) + + if self.output_past: + presents = presents + (present,) if self.output_attentions: all_attentions.append(outputs[2]) @@ -300,7 +304,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - outputs = (hidden_states, presents) + outputs = (hidden_states,) + if self.output_past: + outputs = outputs + (presents,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) if self.output_attentions: diff --git a/transformers/modeling_tf_xlnet.py b/transformers/modeling_tf_xlnet.py index 904c2f4af0..8a25be78c1 100644 --- a/transformers/modeling_tf_xlnet.py +++ b/transformers/modeling_tf_xlnet.py @@ -354,6 +354,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): super(TFXLNetMainLayer, self).__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states + self.output_past = config.output_past self.mem_len = config.mem_len self.reuse_len = config.reuse_len @@ -413,16 +414,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): def cache_mem(self, curr_out, prev_mem): """cache hidden states into memory.""" - if self.mem_len is None or self.mem_len == 0: - return None - else: - if self.reuse_len is not None and self.reuse_len > 0: - curr_out = curr_out[:self.reuse_len] + if self.reuse_len is not None and self.reuse_len > 0: + curr_out = curr_out[:self.reuse_len] - if prev_mem is None: - new_mem = curr_out[-self.mem_len:] - else: - new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len:] + if prev_mem is None: + new_mem = curr_out[-self.mem_len:] + else: + new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len:] return tf.stop_gradient(new_mem) @@ -538,8 +536,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): raise ValueError('Unsupported attention type: {}'.format(self.attn_type)) # data mask: input mask & perm mask - assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " - "or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one." + assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \ + "or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one." if input_mask is None and attention_mask is not None: input_mask = 1.0 - attention_mask if input_mask is not None and perm_mask is not None: @@ -624,7 +622,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): hidden_states = [] for i, layer_module in enumerate(self.layer): # cache new mems - new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) + if self.mem_len is not None and self.mem_len > 0 and self.output_past: + new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) if self.output_hidden_states: hidden_states.append((output_h, output_g) if output_g is not None else output_h) @@ -642,7 +641,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): output = self.dropout(output_g if output_g is not None else output_h, training=training) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) - outputs = (tf.transpose(output, perm=(1, 0, 2)), new_mems) + outputs = (tf.transpose(output, perm=(1, 0, 2)),) + + if self.mem_len is not None and self.mem_len > 0 and self.output_past: + outputs = outputs + (new_mems,) + if self.output_hidden_states: if output_g is not None: hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) @@ -653,7 +656,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) outputs = outputs + (attentions,) - return outputs # outputs, new_mems, (hidden_states), (attentions) + return outputs # outputs, (new_mems), (hidden_states), (attentions) class TFXLNetPreTrainedModel(TFPreTrainedModel): @@ -768,7 +771,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)`` Sequence of hidden-states at the last layer of the model. - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``tf.Tensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -810,7 +813,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``tf.Tensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -854,7 +857,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it - return outputs # return logits, mems, (hidden states), (attentions) + return outputs # return logits, (mems), (hidden states), (attentions) @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of @@ -865,7 +868,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)`` Classification (or regression if config.num_labels==1) scores (before SoftMax). - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``tf.Tensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -909,7 +912,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel): outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it - return outputs # return logits, mems, (hidden states), (attentions) + return outputs # return logits, (mems), (hidden states), (attentions) # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @@ -923,6 +926,11 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel): Span-start scores (before SoftMax). **end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)`` Span-end scores (before SoftMax). + **mems**: (`optional`, returned when ``config.mem_len > 0``) + list of ``tf.Tensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. + See details in the docstring of the `mems` input above. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: @@ -962,7 +970,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel): outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it - return outputs # start_logits, end_logits, (hidden_states), (attentions) + return outputs # start_logits, end_logits, (mems), (hidden_states), (attentions) # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of # the hidden-states output to compute `span start logits` and `span end logits`). """, diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index d6bb2ebd38..2f93dc3816 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -555,7 +555,7 @@ class XLNetModel(XLNetPreTrainedModel): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` Sequence of hidden-states at the last layer of the model. - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -581,6 +581,7 @@ class XLNetModel(XLNetPreTrainedModel): super(XLNetModel, self).__init__(config) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states + self.output_past = config.output_past self.mem_len = config.mem_len self.reuse_len = config.reuse_len @@ -637,16 +638,13 @@ class XLNetModel(XLNetPreTrainedModel): def cache_mem(self, curr_out, prev_mem): """cache hidden states into memory.""" - if self.mem_len is None or self.mem_len == 0: - return None - else: - if self.reuse_len is not None and self.reuse_len > 0: - curr_out = curr_out[:self.reuse_len] + if self.reuse_len is not None and self.reuse_len > 0: + curr_out = curr_out[:self.reuse_len] - if prev_mem is None: - new_mem = curr_out[-self.mem_len:] - else: - new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:] + if prev_mem is None: + new_mem = curr_out[-self.mem_len:] + else: + new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:] return new_mem.detach() @@ -817,8 +815,9 @@ class XLNetModel(XLNetPreTrainedModel): attentions = [] hidden_states = [] for i, layer_module in enumerate(self.layer): - # cache new mems - new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) + if self.mem_len is not None and self.mem_len > 0 and self.output_past: + # cache new mems + new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) if self.output_hidden_states: hidden_states.append((output_h, output_g) if output_g is not None else output_h) @@ -836,7 +835,11 @@ class XLNetModel(XLNetPreTrainedModel): output = self.dropout(output_g if output_g is not None else output_h) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) - outputs = (output.permute(1, 0, 2).contiguous(), new_mems) + outputs = (output.permute(1, 0, 2).contiguous(),) + + if self.mem_len is not None and self.mem_len > 0 and self.output_past: + outputs = outputs + (new_mems,) + if self.output_hidden_states: if output_g is not None: hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs) @@ -847,7 +850,7 @@ class XLNetModel(XLNetPreTrainedModel): attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions) outputs = outputs + (attentions,) - return outputs # outputs, new_mems, (hidden_states), (attentions) + return outputs # outputs, (new_mems), (hidden_states), (attentions) @add_start_docstrings("""XLNet Model with a language modeling head on top @@ -867,7 +870,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): Language modeling loss. **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -932,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): labels.view(-1)) outputs = (loss,) + outputs - return outputs # return (loss), logits, mems, (hidden states), (attentions) + return outputs # return (loss), logits, (mems), (hidden states), (attentions) @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of @@ -951,7 +954,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): Classification (or regression if config.num_labels==1) loss. **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` Classification (or regression if config.num_labels==1) scores (before SoftMax). - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -1011,7 +1014,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) outputs = (loss,) + outputs - return outputs # return (loss), logits, mems, (hidden states), (attentions) + return outputs # return (loss), logits, (mems), (hidden states), (attentions) @add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RACE/SWAG tasks. """, @@ -1046,6 +1049,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above). Classification scores (before SoftMax). + **mems**: (`optional`, returned when ``config.mem_len > 0``) + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. + See details in the docstring of the `mems` input above. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: @@ -1102,7 +1110,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): loss = loss_fct(reshaped_logits, labels.view(-1)) outputs = (loss,) + outputs - return outputs # return (loss), logits, mems, (hidden states), (attentions) + return outputs # return (loss), logits, (mems), (hidden states), (attentions) @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @@ -1126,7 +1134,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): Span-start scores (before SoftMax). **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` Span-end scores (before SoftMax). - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. @@ -1197,7 +1205,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): total_loss = (start_loss + end_loss) / 2 outputs = (total_loss,) + outputs - return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) + return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions) @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @@ -1239,7 +1247,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size,)`` Log probabilities for the ``is_impossible`` label of the answers. - **mems**: + **mems**: (`optional`, returned when ``config.mem_len > 0``) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. diff --git a/transformers/tests/modeling_tf_xlnet_test.py b/transformers/tests/modeling_tf_xlnet_test.py index 6a0434938f..2c80c4fedb 100644 --- a/transformers/tests/modeling_tf_xlnet_test.py +++ b/transformers/tests/modeling_tf_xlnet_test.py @@ -161,6 +161,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): "outputs": outputs.numpy(), } + model.config.mem_len = 0 + no_mems_outputs = model(inputs) + self.parent.assertEqual(len(no_mems_outputs), 1) + self.parent.assertListEqual( list(result["outputs"].shape), [self.batch_size, self.seq_length, self.hidden_size]) diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index 10cbdaf37b..293dffabf6 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -150,6 +150,10 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): "outputs": outputs, } + model.config.mem_len = 0 + no_mems_outputs = model(input_ids_1) + self.parent.assertEqual(len(no_mems_outputs), 1) + self.parent.assertListEqual( list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size])