Merge pull request #2289 from patrickvonplaten/fix_effective_batch_size_lang_gen_xlm
Fix bug in preparing inputs for language generation with XLM (and XLNet) when the effective batch size is greater than 1
This commit is contained in:
@@ -674,7 +674,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
|||||||
mask_token_id = self.config.mask_token_id
|
mask_token_id = self.config.mask_token_id
|
||||||
lang_id = self.config.lang_id
|
lang_id = self.config.lang_id
|
||||||
|
|
||||||
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
|
effective_batch_size = input_ids.shape[0]
|
||||||
|
mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
|
||||||
input_ids = torch.cat([input_ids, mask_token], dim=1)
|
input_ids = torch.cat([input_ids, mask_token], dim=1)
|
||||||
if lang_id is not None:
|
if lang_id is not None:
|
||||||
langs = torch.full_like(input_ids, lang_id)
|
langs = torch.full_like(input_ids, lang_id)
|
||||||
|
|||||||
@@ -1010,18 +1010,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
||||||
# Add dummy token at the end (no attention on this one)
|
# Add dummy token at the end (no attention on this one)
|
||||||
dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
|
|
||||||
|
effective_batch_size = input_ids.shape[0]
|
||||||
|
dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device)
|
||||||
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
||||||
|
|
||||||
# Build permutation mask so that previous tokens don't see last token
|
# Build permutation mask so that previous tokens don't see last token
|
||||||
|
sequence_length = input_ids.shape[1]
|
||||||
perm_mask = torch.zeros(
|
perm_mask = torch.zeros(
|
||||||
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=input_ids.device
|
(effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device
|
||||||
)
|
)
|
||||||
perm_mask[:, :, -1] = 1.0
|
perm_mask[:, :, -1] = 1.0
|
||||||
|
|
||||||
# We'll only predict the last token
|
# We'll only predict the last token
|
||||||
target_mapping = torch.zeros(
|
target_mapping = torch.zeros(
|
||||||
(input_ids.shape[0], 1, input_ids.shape[1]), dtype=torch.float, device=input_ids.device
|
(effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
|
||||||
)
|
)
|
||||||
target_mapping[0, 0, -1] = 1.0
|
target_mapping[0, 0, -1] = 1.0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user