[Encoder-Decoder] Force models outputs to always have batch_size as their first dim (#3536)
* solve conflicts * improve comments
This commit is contained in:
committed by
GitHub
parent
ab5d06a094
commit
390c128592
@@ -948,18 +948,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
cur_len = 1
|
||||
batch_idx = self.encoder_outputs_batch_dim_idx
|
||||
|
||||
assert (
|
||||
batch_size == encoder_outputs[0].shape[batch_idx]
|
||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
|
||||
expanded_idx = (
|
||||
batch_size == encoder_outputs[0].shape[0]
|
||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
||||
|
||||
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||
expanded_batch_idxs = (
|
||||
torch.arange(batch_size)
|
||||
.view(-1, 1)
|
||||
.repeat(1, num_beams * effective_batch_mult)
|
||||
.view(-1)
|
||||
.to(input_ids.device)
|
||||
)
|
||||
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
|
||||
# expand encoder_outputs
|
||||
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
|
||||
|
||||
else:
|
||||
encoder_outputs = None
|
||||
|
||||
Reference in New Issue
Block a user