Fix: Correct tensor shape comment in Mamba modeling (#37801)
* Fix: Correct tensor shape comment in Mamba modeling * Update src/transformers/models/mamba/modeling_mamba.py * Update src/transformers/models/mamba/modeling_mamba.py --------- Co-authored-by: ShadyPi <11342288+shadypi@user.noreply.gitee.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -301,10 +301,10 @@ class MambaMixer(nn.Module):
|
||||
else:
|
||||
scan_outputs = []
|
||||
for i in range(seq_len):
|
||||
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
|
||||
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
|
||||
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
|
||||
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
|
||||
scan_outputs.append(scan_output[:, :, 0])
|
||||
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
|
||||
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
|
||||
scan_output = scan_output + (hidden_states * self.D[None, :, None])
|
||||
scan_output = (scan_output * self.act(gate))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user