fix tests

This commit is contained in:
thomwolf
2019-06-26 10:02:45 +02:00
parent 092dacfd62
commit 93e9971c54
6 changed files with 38 additions and 149 deletions

View File

@@ -930,12 +930,12 @@ all_hidden_states = lower_hidden_states + [hidden_states]
`TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings.
*Inputs* are the same as the inputs of the [`TransfoXLModel`](#-12.-`TransfoXLModel`) class plus optional labels:
- `target`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the target token indices selected in the range [0, self.config.n_token[
- `labels`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the labels token indices selected in the range [0, self.config.n_token[
*Outputs* a tuple of (last_hidden_state, new_mems)
- `softmax_output`: output of the (adaptive) softmax:
- if target is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
- else: Negative log likelihood of target tokens with shape [batch_size, sequence_length]
- if labels is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
- else: Negative log likelihood of labels tokens with shape [batch_size, sequence_length]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
#### 14. `GPT2Model`