Add re-ordering of the cached `past` states for beam search

This commit is contained in:
patrickvonplaten
2019-12-25 16:29:20 +01:00
parent 6bca56fdb0
commit 90cda45e9e

View File

@@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module):
beam_words = input_ids.new([x[1] for x in next_batch_beam]) beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam]) beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and internal states # re-order batch
input_ids = input_ids[beam_idx, :] input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1) input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
# TODO: Activate cache
# for k in cache.keys(): # re-order internal states
# if k != 'slen': if past:
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx]) reordered_past = []
for layer_past in past:
# select each chosen beam's slice of layer_past (indexed along dim 1), in beam_idx order
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
past = tuple(reordered_past)
# update current length # update current length
cur_len = cur_len + 1 cur_len = cur_len + 1