From 0e774e57a67654ccb13e8684a2c08f7b11da9fb0 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Thu, 14 Feb 2019 08:39:58 +0100 Subject: [PATCH] Update readme Adding details on how to extract a full list of hidden states for the Transformer-XL --- README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8b7156cfba..a244ef8dac 100644 --- a/README.md +++ b/README.md @@ -624,6 +624,18 @@ This model *outputs* a tuple of (last_hidden_state, new_mems) - `last_hidden_state`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model] - `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`. +##### Extracting a list of the hidden states at each layer of the Transformer-XL from `last_hidden_state` and `new_mems`: +The `new_mems` contain all the hidden states PLUS the output of the embeddings (`new_mems[0]`). `new_mems[-1]` is the output of the hidden state of the layer below the last layer and `last_hidden_state` is the output of the last layer (i.E. the input of the softmax when we have a language modeling head on top). + +There are two differences between the shapes of `new_mems` and `last_hidden_state`: `new_mems` have transposed first dimensions and are longer (of size `self.config.mem_len`). Here is how to extract the full list of hidden states from the model output: + +```python +hidden_states, mems = model(tokens_tensor) +seq_length = hidden_states.size(1) +lower_hidden_states = list(t[-seq_length:, ...].transpose(0, 1) for t in mems) +all_hidden_states = lower_hidden_states + [hidden_states] +``` + #### 13. `TransfoXLLMHeadModel` `TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings. @@ -637,7 +649,6 @@ This model *outputs* a tuple of (last_hidden_state, new_mems) - else: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens] - `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`. - ### Tokenizers: #### `BertTokenizer`