From 90cda45e9e7da95f9084ceca6d631f64173b69c8 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 16:29:20 +0100 Subject: [PATCH] add past re-ordering for beam search --- src/transformers/modeling_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f81bcbecae..c0eaec9c2c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module): beam_words = input_ids.new([x[1] for x in next_batch_beam]) beam_idx = input_ids.new([x[2] for x in next_batch_beam]) - # re-order batch and internal states + # re-order batch input_ids = input_ids[beam_idx, :] input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1) - # TODO: Activate cache - # for k in cache.keys(): - # if k != 'slen': - # cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx]) + + # re-order internal states + if past: + reordered_past = [] + for layer_past in past: + # copy the relevant beam idx past to past + reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] + reordered_past.append(torch.cat(reordered_layer_past, dim=1)) + past = tuple(reordered_past) # update current length cur_len = cur_len + 1