add easy tensor shape match test
This commit is contained in:
@@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module):
|
|||||||
for layer_past in past:
|
for layer_past in past:
|
||||||
# copy the relevant beam idx past to past
|
# copy the relevant beam idx past to past
|
||||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||||
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
|
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||||
|
# check that shape matches
|
||||||
|
assert reordered_layer_past.shape == layer_past.shape
|
||||||
|
reordered_past.append(reordered_layer_past)
|
||||||
past = tuple(reordered_past)
|
past = tuple(reordered_past)
|
||||||
|
|
||||||
# update current length
|
# update current length
|
||||||
|
|||||||
Reference in New Issue
Block a user