add past re-ordering for beam search
This commit is contained in:
@@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module):
|
|||||||
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
||||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||||
|
|
||||||
# re-order batch and internal states
|
# re-order batch
|
||||||
input_ids = input_ids[beam_idx, :]
|
input_ids = input_ids[beam_idx, :]
|
||||||
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
|
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
|
||||||
# TODO: Activate cache
|
|
||||||
# for k in cache.keys():
|
# re-order internal states
|
||||||
# if k != 'slen':
|
if past:
|
||||||
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
|
reordered_past = []
|
||||||
|
for layer_past in past:
|
||||||
|
# copy the relevant beam idx past to past
|
||||||
|
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||||
|
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
|
||||||
|
past = tuple(reordered_past)
|
||||||
|
|
||||||
# update current length
|
# update current length
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|||||||
Reference in New Issue
Block a user