From dc89441167d867ce9c5a9280e963ef2f755a8ba3 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 7 Oct 2019 15:37:25 +0200 Subject: [PATCH] update CTRL pytorch model --- transformers/modeling_ctrl.py | 397 +++++++++++++++++++++------------- 1 file changed, 241 insertions(+), 156 deletions(-) diff --git a/transformers/modeling_ctrl.py b/transformers/modeling_ctrl.py index 3633c2d8fb..cef2a666a4 100644 --- a/transformers/modeling_ctrl.py +++ b/transformers/modeling_ctrl.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch CTRL model.""" +""" PyTorch CTRL model.""" from __future__ import absolute_import, division, print_function, unicode_literals @@ -27,7 +27,6 @@ from io import open import numpy as np import torch import torch.nn as nn -import pdb from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter @@ -41,148 +40,168 @@ CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf- def angle_defn(pos, i, d_model_size): - angle_rates = 1 / torch.pow(10000, (2 * (i//2)) / d_model_size) - return pos * angle_rates + angle_rates = 1 / torch.pow(10000, (2 * (i//2)) / d_model_size) + return pos * angle_rates def positional_encoding(position, d_model_size, dtype): - # create the sinusoidal pattern for the positional encoding - angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1), torch.arange(d_model_size, dtype=dtype).unsqueeze(0), d_model_size)) - - sines = torch.sin(angle_rads[:, 0::2]) - cosines = torch.cos(angle_rads[:, 1::2]) - - pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0) - return pos_encoding + # create the sinusoidal pattern for the positional encoding + angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1), + torch.arange(d_model_size, dtype=dtype).unsqueeze(0), + d_model_size)) -def scaled_dot_product_attention(q, k, v, mask): - # calculate attention - matmul_qk = torch.matmul(q, k.permute(0,1,3,2)) - - dk = k.shape[-1] - scaled_attention_logits = matmul_qk / np.sqrt(dk) + sines = torch.sin(angle_rads[:, 0::2]) + cosines = torch.cos(angle_rads[:, 1::2]) - if mask is not None: - scaled_attention_logits += (mask * -1e4) - - attention_weights = torch.softmax(scaled_attention_logits, dim=-1) - output = torch.matmul(attention_weights, v) - - return output, attention_weights + pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0) + return pos_encoding + +def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None): + # calculate attention + matmul_qk = torch.matmul(q, k.permute(0,1,3,2)) + + dk = k.shape[-1] + scaled_attention_logits = matmul_qk / np.sqrt(dk) + + if mask is not None: + scaled_attention_logits += (mask * -1e4) + + if attention_mask is not None: + # Apply the attention mask + scaled_attention_logits = scaled_attention_logits + attention_mask + + attention_weights = torch.softmax(scaled_attention_logits, dim=-1) + + # Mask heads if we want to + if head_mask is not None: + attention_weights = attention_weights * head_mask + + output = torch.matmul(attention_weights, v) + + return output, attention_weights class MultiHeadAttention(torch.nn.Module): - def __init__(self, d_model_size, num_heads): - super(MultiHeadAttention, self).__init__() - self.num_heads = num_heads - self.d_model_size = d_model_size - - self.depth = int(d_model_size / self.num_heads) - - self.Wq = torch.nn.Linear(d_model_size, d_model_size) - self.Wk = torch.nn.Linear(d_model_size, d_model_size) - self.Wv = torch.nn.Linear(d_model_size, d_model_size) - - self.dense = torch.nn.Linear(d_model_size, d_model_size) - - def split_into_heads(self, x, batch_size): - x = x.reshape(batch_size, -1, self.num_heads, self.depth) - return x.permute([0, 2, 1, 3]) - - def forward(self, v, k, q, mask): - batch_size = q.shape[0] - - q = self.Wq(q) - k = self.Wk(k) - v = self.Wv(v) - - q = self.split_into_heads(q, batch_size) - k = self.split_into_heads(k, batch_size) - v = self.split_into_heads(v, batch_size) - output = scaled_dot_product_attention(q, k, v, mask) - scaled_attention = output[0].permute([0, 2, 1, 3]) - attn = output[1] - original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size) - output = self.dense(original_size_attention) - - return output, attn + def __init__(self, d_model_size, num_heads, output_attentions=False): + super(MultiHeadAttention, self).__init__() + self.output_attentions = output_attentions + self.num_heads = num_heads + self.d_model_size = d_model_size + + self.depth = int(d_model_size / self.num_heads) + + self.Wq = torch.nn.Linear(d_model_size, d_model_size) + self.Wk = torch.nn.Linear(d_model_size, d_model_size) + self.Wv = torch.nn.Linear(d_model_size, d_model_size) + + self.dense = torch.nn.Linear(d_model_size, d_model_size) + + def split_into_heads(self, x, batch_size): + x = x.reshape(batch_size, -1, self.num_heads, self.depth) + return x.permute([0, 2, 1, 3]) + + def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None): + batch_size = q.shape[0] + + q = self.Wq(q) + k = self.Wk(k) + v = self.Wv(v) + + q = self.split_into_heads(q, batch_size) + k = self.split_into_heads(k, batch_size) + v = self.split_into_heads(v, batch_size) + if layer_past is not None: + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below + k = torch.cat((past_key, k), dim=-1) + v = torch.cat((past_value, v), dim=-2) + present = torch.stack((k.transpose(-2, -1), v)) # transpose to have same shapes for stacking + + output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions) + scaled_attention = output[0].permute([0, 2, 1, 3]) + attn = output[1] + original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size) + output = self.dense(original_size_attention) + + return output, attn def point_wise_feed_forward_network(d_model_size, dff): - return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size)) + return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), + torch.nn.ReLU(), + torch.nn.Linear(dff, d_model_size)) class EncoderLayer(torch.nn.Module): - def __init__(self, d_model_size, num_heads, dff, rate=0.1): - super(EncoderLayer, self).__init__() + def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False): + super(EncoderLayer, self).__init__() - self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads) - self.ffn = point_wise_feed_forward_network(d_model_size, dff) + self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions) + self.ffn = point_wise_feed_forward_network(d_model_size, dff) - self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6) - self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6) - - self.dropout1 = torch.nn.Dropout(rate) - self.dropout2 = torch.nn.Dropout(rate) - - def forward(self, x, mask): - normed = self.layernorm1(x) - attn_output, attn = self.multi_head_attention(normed, normed, normed, mask) - attn_output = self.dropout1(attn_output) - out1 = x + attn_output + self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6) + self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6) - out2 = self.layernorm2(out1) - ffn_output = self.ffn(out2) - ffn_output = self.dropout2(ffn_output) - out2 = out1 + ffn_output - - return out2, attn + self.dropout1 = torch.nn.Dropout(rate) + self.dropout2 = torch.nn.Dropout(rate) + + def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None): + normed = self.layernorm1(x) + attn_output, attn = self.multi_head_attention(normed, normed, normed, mask, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask) + attn_output = self.dropout1(attn_output) + out1 = x + attn_output + + out2 = self.layernorm2(out1) + ffn_output = self.ffn(out2) + ffn_output = self.dropout2(ffn_output) + out2 = out1 + ffn_output + + return out2, attn class CTRLPreTrainedModel(PreTrainedModel): - """ An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - config_class = CTRLConfig - pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP - base_model_prefix = "transformer" - - def __init__(self, *inputs, **kwargs): - super(CTRLPreTrainedModel, self).__init__(*inputs, **kwargs) - - def _init_weights(self, module): - """ Initialize the weights. + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. """ - if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: + config_class = CTRLConfig + pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "transformer" + + def _init_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.weight.data.fill_(1.0) CTRL_START_DOCSTRING = r""" CTRL model was proposed in - `CTRL: A Conditional Transformer Language Model for Controllable Generation`_ - by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. - It's a causal (unidirectional) transformer pre-trained using language modeling on a very large - corpus of ~140 GB of text data with the first token reserved as a control code (such as Links, Books, Wikipedia etc.). + `CTRL: A Conditional Transformer Language Model for Controllable Generation`_ + by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. + It's a causal (unidirectional) transformer pre-trained using language modeling on a very large + corpus of ~140 GB of text data with the first token reserved as a control code (such as Links, Books, Wikipedia etc.). - This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and - refer to the PyTorch documentation for all matter related to general usage and behavior. + This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matter related to general usage and behavior. - .. _`CTRL: A Conditional Transformer Language Model for Controllable Generation`: - https://www.github.com/salesforce/ctrl + .. _`CTRL: A Conditional Transformer Language Model for Controllable Generation`: + https://www.github.com/salesforce/ctrl - .. _`torch.nn.Module`: - https://pytorch.org/docs/stable/nn.html#module + .. _`torch.nn.Module`: + https://pytorch.org/docs/stable/nn.html#module - Parameters: - config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. + Parameters: + config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ CTRL_INPUTS_DOCSTRING = r""" Inputs: @@ -215,7 +234,7 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs: """ @add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.", - CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING) + CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING) class CTRLModel(CTRLPreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: @@ -249,14 +268,18 @@ class CTRLModel(CTRLPreTrainedModel): self.num_layers = config.n_layer self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float) - + self.output_attentions = config.output_attentions self.w = nn.Embedding(config.vocab_size, config.n_embd) self.dropout = nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)]) + self.h = nn.ModuleList([EncoderLayer(config.n_embd, + config.n_head, + config.dff, + config.resid_pdrop, + config.output_attentions) for _ in range(config.n_layer)]) self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.init_weights() @@ -267,44 +290,104 @@ class CTRLModel(CTRLPreTrainedModel): def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} """ for layer, heads in heads_to_prune.items(): self.h[layer].attn.prune_heads(heads) - def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, - labels=None): - - embedded = self.w(input_ids) - x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded - seq_len = input_ids.shape[1] - mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device) - - x *= np.sqrt(self.d_model_size) + def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) - x += self.pos_encoding[:, :seq_len, :].to(x.device) - - x = self.dropout(x) - all_hidden_states = () - all_attentions = [] - for i in range(self.num_layers): + if past is None: + past_length = 0 + past = [None] * len(self.h) + else: + past_length = past[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(-1, input_shape[-1]) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.n_layer + + embedded = self.w(input_ids) + x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded + seq_len = input_ids.shape[1] + mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device) + + x *= np.sqrt(self.d_model_size) + + x += self.pos_encoding[:, position_ids, :].to(x.device) + + x = self.dropout(x) + + output_shape = input_shape + (x.size(-1),) + presents = () + all_hidden_states = () + all_attentions = [] + for i, (h, layer_past) in enumerate(zip(self.h, past)): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (x.view(*output_shape),) + outputs = h(x, + mask, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i]) + x, present = outputs[:2] + presents = presents + (present,) + + if self.output_attentions: + all_attentions.append(outputs[2]) + + x = self.layernorm(x) + x = x.view(*output_shape) if self.output_hidden_states: - all_hidden_states = all_hidden_states + (x,) - x, attn = self.h[i](x, mask) + all_hidden_states = all_hidden_states + (x,) + + outputs = (x, presents) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) if self.output_attentions: - all_attentions.append(attn) + # let the number of heads free (-1) so we can extract attention even after head pruning + attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] + all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) + outputs = outputs + (all_attentions,) + return outputs - x = self.layernorm(x) - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (x,) - - outputs = (x, None) - if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) - if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs - @add_start_docstrings("""The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """, CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING) @@ -357,15 +440,19 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def tie_weights(self): """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. + Export to TorchScript can't handle parameter sharing so we are cloning them instead. """ - self._tie_or_clone_weights(self.lm_head, - self.transformer.w) - #self._tie_or_clone_weights(self.lm_head.bias, - # self.transformer.w.bias) + self._tie_or_clone_weights(self.lm_head, self.transformer.w) + def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None): - transformer_outputs = self.transformer(input_ids) + transformer_outputs = self.transformer(input_ids, + past=past, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask) + hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -383,5 +470,3 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): outputs = (loss,) + outputs return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions) - -