From d98a384cb0ec56e5e83d22d75a536052e21305a0 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 21:45:26 +0100 Subject: [PATCH 1/3] fix bug in prepare inputs for language generation for xlm for effective batch_size > 1 --- src/transformers/modeling_xlm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_xlm.py b/src/transformers/modeling_xlm.py index 398d385aa5..b8d9348dae 100644 --- a/src/transformers/modeling_xlm.py +++ b/src/transformers/modeling_xlm.py @@ -674,7 +674,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): mask_token_id = self.config.mask_token_id lang_id = self.config.lang_id - mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device) + effective_batch_size = input_ids.shape[0] + mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device) input_ids = torch.cat([input_ids, mask_token], dim=1) if lang_id is not None: langs = torch.full_like(input_ids, lang_id) From 359dc438374b4eb90a034a1c65b785bb6c8b2e24 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Tue, 24 Dec 2019 16:33:20 +0100 Subject: [PATCH 2/3] fix effective batch_size error in prepare_inputs also for xlnet --- src/transformers/modeling_xlnet.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 1b53ba92c2..0ad312cfc5 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1010,18 +1010,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def prepare_inputs_for_generation(self, input_ids, **model_kwargs): # Add dummy token at the end (no attention on this one) - dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device) + + effective_batch_size = input_ids.shape[0] + sequence_length = input_ids.shape[1] + dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device) input_ids = torch.cat([input_ids, dummy_token], dim=1) # Build permutation mask so that previous tokens don't see last token perm_mask = torch.zeros( - (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=input_ids.device + (effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device ) perm_mask[:, :, -1] = 1.0 # We'll only predict the last token target_mapping = torch.zeros( - (input_ids.shape[0], 1, input_ids.shape[1]), dtype=torch.float, device=input_ids.device + (effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device ) target_mapping[0, 0, -1] = 1.0 From f18ac4c28e6382c9fce4e4a5c8a3c645fe4bdd98 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Tue, 24 Dec 2019 16:43:24 +0100 Subject: [PATCH 3/3] fix sequence length for prepare_inputs for xlnet --- src/transformers/modeling_xlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 0ad312cfc5..be9c41b0e5 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1012,11 +1012,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): # Add dummy token at the end (no attention on this one) effective_batch_size = input_ids.shape[0] - sequence_length = input_ids.shape[1] dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device) input_ids = torch.cat([input_ids, dummy_token], dim=1) # Build permutation mask so that previous tokens don't see last token + sequence_length = input_ids.shape[1] perm_mask = torch.zeros( (effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device )