From cec3cdda1599541b033e07a9838386189a5d0010 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 19 Mar 2020 09:55:17 +0100 Subject: [PATCH] Fix input ids can be none attn mask (#3345) * fix issue 3289 * fix attention mask if input_ids None behavior --- src/transformers/modeling_ctrl.py | 5 ++++- src/transformers/modeling_gpt2.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index f9c6202861..69ddc407a9 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -330,8 +330,10 @@ class CTRLModel(CTRLPreTrainedModel): elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -347,7 +349,8 @@ class CTRLModel(CTRLPreTrainedModel): # Attention mask. if attention_mask is not None: - attention_mask = attention_mask.view(-1, input_shape[-1]) + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 04a95eff28..94fb3ac1db 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -402,8 +402,10 @@ class GPT2Model(GPT2PreTrainedModel): elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -424,7 +426,7 @@ class GPT2Model(GPT2PreTrainedModel): # Attention mask. if attention_mask is not None: - batch_size = input_ids.shape[0] + assert batch_size > 0, "batch_size has to be defined and > 0" attention_mask = attention_mask.view(batch_size, -1) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length]