From 9398058e19d1ca89c881890b6dad72e384cc88c6 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 16:34:28 +0100 Subject: [PATCH] add easy tensor shape match test --- src/transformers/modeling_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c0eaec9c2c..437ec8f6f0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module): for layer_past in past: # copy the relevant beam idx past to past reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] - reordered_past.append(torch.cat(reordered_layer_past, dim=1)) + reordered_layer_past = torch.cat(reordered_layer_past, dim=1) + # check that shape matches + assert reordered_layer_past.shape == layer_past.shape + reordered_past.append(reordered_layer_past) past = tuple(reordered_past) # update current length