Ensure OpenAI GPT position_ids is correctly initialized and registered at init. (#5773)
* Ensure OpenAI GPT position_ids is correctly initialized and registered as buffer at init. This will make it compatible with TorchScript export. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fix missing slice operator on the tensor data accessor. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Style. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fixed BertEmbedding position_ids buffer created at forward. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fixed MobileBertEmbedding position_ids buffer created at forward. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fixed XLM position_ids buffer created at forward. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
@@ -179,18 +179,22 @@ class MobileBertEmbeddings(nn.Module):
|
||||
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
else:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user