add easy tensor shape match test
This commit is contained in:
@@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module):
|
||||
for layer_past in past:
|
||||
# copy the relevant beam idx past to past
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
assert reordered_layer_past.shape == layer_past.shape
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
|
||||
# update current length
|
||||
|
||||
Reference in New Issue
Block a user