add easy tensor shape match test

This commit is contained in:
patrickvonplaten
2019-12-25 16:34:28 +01:00
parent 90cda45e9e
commit 9398058e19

View File

@@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module):
for layer_past in past:
# copy the relevant beam idx past to past
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
# update current length