From 614fef1691edb806de976756d4948ecbcd0c0ca3 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Fri, 24 Jul 2020 15:37:52 +0200 Subject: [PATCH] Ensure OpenAI GPT position_ids is correctly initialized and registered at init. (#5773) * Ensure OpenAI GPT position_ids is correctly initialized and registered as buffer at init. This will make it compatible with TorchScript export. Signed-off-by: Morgan Funtowicz * Fix missing slice operator on the tensor data accessor. Signed-off-by: Morgan Funtowicz * Style. Signed-off-by: Morgan Funtowicz * Fixed BertEmbedding position_ids buffer created at forward. Signed-off-by: Morgan Funtowicz * Fixed MobileBertEmbedding position_ids buffer created at forward. Signed-off-by: Morgan Funtowicz * Fixed XLM position_ids buffer created at forward. Signed-off-by: Morgan Funtowicz --- src/transformers/modeling_bert.py | 11 +++++++---- src/transformers/modeling_mobilebert.py | 12 ++++++++---- src/transformers/modeling_openai.py | 5 ++--- src/transformers/modeling_xlm.py | 6 ++---- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 281ecd766f..e27ba7539c 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -180,6 +180,9 @@ class BertEmbeddings(nn.Module): self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): if input_ids is not None: input_shape = input_ids.size() @@ -187,12 +190,12 @@ class BertEmbeddings(nn.Module): input_shape = inputs_embeds.size()[:-1] seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) + position_ids = self.position_ids[:, :seq_length] + if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) diff --git a/src/transformers/modeling_mobilebert.py b/src/transformers/modeling_mobilebert.py index a32957c522..b01c29df29 100644 --- a/src/transformers/modeling_mobilebert.py +++ b/src/transformers/modeling_mobilebert.py @@ -179,18 +179,22 @@ class MobileBertEmbeddings(nn.Module): self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): if input_ids is not None: input_shape = input_ids.size() else: input_shape = inputs_embeds.size()[:-1] + seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) + position_ids = self.position_ids[:, :seq_length] + if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index 5365b943af..e346219c3d 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -391,6 +391,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) + self.register_buffer("position_ids", torch.arange(config.n_positions)) self.init_weights() def get_input_embeddings(self): @@ -443,9 +444,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): if position_ids is None: # Code is different from when we had a single embedding matrice from position and token embeddings - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_ids = self.position_ids[None, : input_shape[-1]] # Attention mask. if attention_mask is not None: diff --git a/src/transformers/modeling_xlm.py b/src/transformers/modeling_xlm.py index 96c287faef..932bf807a5 100644 --- a/src/transformers/modeling_xlm.py +++ b/src/transformers/modeling_xlm.py @@ -442,6 +442,7 @@ class XLMModel(XLMPreTrainedModel): self.prune_heads({int(layer): list(map(int, heads))}) self.init_weights() + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) def get_input_embeddings(self): return self.embeddings @@ -511,12 +512,9 @@ class XLMModel(XLMPreTrainedModel): # if self.is_decoder and src_enc is not None: # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] - device = input_ids.device if input_ids is not None else inputs_embeds.device - # position_ids if position_ids is None: - position_ids = torch.arange(slen, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand((bs, slen)) + position_ids = self.position_ids[:, :slen] else: assert position_ids.size() == (bs, slen) # (slen, bs) # position_ids = position_ids.transpose(0, 1)