Fix: Fix FalconMamba training issues due to incompatible kernels (#33195)
* fix FM training kernels * fix copies * fix copies * propagate to slow path * make it BC * add comment * fix test
This commit is contained in:
15
src/transformers/kernels/falcon_mamba/__init__.py
Normal file
15
src/transformers/kernels/falcon_mamba/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .selective_scan_with_ln_interface import mamba_inner_fn
|
||||
@@ -0,0 +1,525 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
try:
|
||||
import causal_conv1d_cuda
|
||||
except ImportError:
|
||||
causal_conv1d_cuda = None
|
||||
|
||||
import mamba_ssm
|
||||
import selective_scan_cuda
|
||||
|
||||
|
||||
# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
|
||||
if hasattr(mamba_ssm.ops.triton, "layernorm"):
|
||||
from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
|
||||
else:
|
||||
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
|
||||
|
||||
|
||||
class SelectiveScanFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
|
||||
):
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3:
|
||||
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
||||
ctx.squeeze_B = True
|
||||
if C.dim() == 3:
|
||||
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
||||
ctx.squeeze_C = True
|
||||
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
||||
ctx.delta_softplus = delta_softplus
|
||||
ctx.has_z = z is not None
|
||||
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
||||
if not ctx.has_z:
|
||||
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
else:
|
||||
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
||||
out_z = rest[0]
|
||||
return out_z if not return_last_state else (out_z, last_state)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
if not ctx.has_z:
|
||||
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
||||
z = None
|
||||
out = None
|
||||
else:
|
||||
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan_cuda with the backward of chunk).
|
||||
# Here we just pass in None and dz will be allocated in the C++ code.
|
||||
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
||||
u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
dout,
|
||||
x,
|
||||
out,
|
||||
None,
|
||||
ctx.delta_softplus,
|
||||
False, # option to recompute out_z, not used here
|
||||
)
|
||||
dz = rest[0] if ctx.has_z else None
|
||||
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
||||
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
||||
return (
|
||||
du,
|
||||
ddelta,
|
||||
dA,
|
||||
dB,
|
||||
dC,
|
||||
dD if D is not None else None,
|
||||
dz,
|
||||
ddelta_bias if delta_bias is not None else None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_forward(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps=1e-6,
|
||||
is_rms_norm=True,
|
||||
):
|
||||
# x (b l) d
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
|
||||
# y (b l) d
|
||||
return y
|
||||
|
||||
|
||||
def selective_scan_fn(
|
||||
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
|
||||
):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
||||
not considered in the backward pass.
|
||||
"""
|
||||
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
||||
|
||||
|
||||
def selective_scan_ref(
|
||||
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
|
||||
):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
if A.is_complex():
|
||||
if is_variable_B:
|
||||
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
||||
if is_variable_C:
|
||||
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
||||
else:
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate))
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
last_state = None
|
||||
for i in range(u.shape[2]):
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum("bdn,dn->bd", x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
last_state = x
|
||||
if y.is_complex():
|
||||
y = y.real * 2
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
|
||||
|
||||
class MambaInnerFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
A,
|
||||
B=None,
|
||||
C=None,
|
||||
D=None,
|
||||
delta_bias=None,
|
||||
B_proj_bias=None,
|
||||
C_proj_bias=None,
|
||||
delta_softplus=True,
|
||||
checkpoint_lvl=1,
|
||||
b_rms_weight=None,
|
||||
c_rms_weight=None,
|
||||
dt_rms_weight=None,
|
||||
b_c_dt_rms_eps=1e-6,
|
||||
):
|
||||
"""
|
||||
xz: (batch, dim, seqlen)
|
||||
"""
|
||||
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
||||
assert checkpoint_lvl in [0, 1]
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
if torch.is_autocast_enabled():
|
||||
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
out_proj_bias = (
|
||||
out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
|
||||
)
|
||||
if xz.stride(-1) != 1:
|
||||
xz = xz.contiguous()
|
||||
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
||||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
|
||||
# We're being very careful here about the layout, to avoid extra transposes.
|
||||
# We want delta to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
|
||||
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
|
||||
ctx.is_variable_B = B is None
|
||||
ctx.is_variable_C = C is None
|
||||
ctx.B_proj_bias_is_None = B_proj_bias is None
|
||||
ctx.C_proj_bias_is_None = C_proj_bias is None
|
||||
if B is None: # variable B
|
||||
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
|
||||
if B_proj_bias is not None:
|
||||
B = B + B_proj_bias.to(dtype=B.dtype)
|
||||
if not A.is_complex():
|
||||
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
else:
|
||||
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
||||
else:
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C is None: # variable C
|
||||
C = x_dbl[:, -d_state:] # (bl dstate)
|
||||
if C_proj_bias is not None:
|
||||
C = C + C_proj_bias.to(dtype=C.dtype)
|
||||
if not A.is_complex():
|
||||
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
else:
|
||||
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
||||
else:
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
|
||||
if b_rms_weight is not None:
|
||||
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
if c_rms_weight is not None:
|
||||
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
if dt_rms_weight is not None:
|
||||
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
|
||||
delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
||||
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
|
||||
|
||||
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
||||
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
||||
)
|
||||
ctx.delta_softplus = delta_softplus
|
||||
ctx.out_proj_bias_is_None = out_proj_bias is None
|
||||
ctx.checkpoint_lvl = checkpoint_lvl
|
||||
ctx.b_rms_weight = b_rms_weight
|
||||
ctx.c_rms_weight = c_rms_weight
|
||||
ctx.dt_rms_weight = dt_rms_weight
|
||||
ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
|
||||
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
||||
conv1d_out, delta = None, None
|
||||
ctx.save_for_backward(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_dbl,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
conv1d_out,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
delta_bias,
|
||||
scan_intermediates,
|
||||
b_rms_weight,
|
||||
c_rms_weight,
|
||||
dt_rms_weight,
|
||||
out,
|
||||
)
|
||||
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, dout):
|
||||
# dout: (batch, seqlen, dim)
|
||||
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
||||
(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_dbl,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
conv1d_out,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
delta_bias,
|
||||
scan_intermediates,
|
||||
b_rms_weight,
|
||||
c_rms_weight,
|
||||
dt_rms_weight,
|
||||
out,
|
||||
) = ctx.saved_tensors
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
if ctx.checkpoint_lvl == 1:
|
||||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
|
||||
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
|
||||
if dt_rms_weight is not None:
|
||||
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
|
||||
delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
|
||||
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
|
||||
if b_rms_weight is not None:
|
||||
# Recompute & RMSNorm B
|
||||
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
if c_rms_weight is not None:
|
||||
# Recompute & RMSNorm C
|
||||
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan_cuda with the backward of chunk).
|
||||
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
||||
dx, dz = dxz.chunk(2, dim=1)
|
||||
dout = rearrange(dout, "b l e -> e (b l)")
|
||||
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
||||
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
||||
conv1d_out,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
dout_y,
|
||||
scan_intermediates,
|
||||
out,
|
||||
dz,
|
||||
ctx.delta_softplus,
|
||||
True, # option to recompute out_z
|
||||
)
|
||||
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
||||
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
||||
dD = dD if D is not None else None
|
||||
dx_dbl = torch.empty_like(x_dbl)
|
||||
dB_proj_bias = None
|
||||
if ctx.is_variable_B:
|
||||
if not A.is_complex():
|
||||
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
||||
else:
|
||||
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
||||
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
||||
dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
|
||||
dB = None
|
||||
dC_proj_bias = None
|
||||
if ctx.is_variable_C:
|
||||
if not A.is_complex():
|
||||
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
||||
else:
|
||||
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
||||
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
||||
dx_dbl[:, -d_state:] = dC # (bl d)
|
||||
dC = None
|
||||
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
||||
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
||||
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
||||
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
||||
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
||||
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
||||
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
||||
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
||||
# backward of conv1d with the backward of chunk).
|
||||
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
||||
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
||||
)
|
||||
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
||||
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
||||
return (
|
||||
dxz,
|
||||
dconv1d_weight,
|
||||
dconv1d_bias,
|
||||
dx_proj_weight,
|
||||
ddelta_proj_weight,
|
||||
dout_proj_weight,
|
||||
dout_proj_bias,
|
||||
dA,
|
||||
dB,
|
||||
dC,
|
||||
dD,
|
||||
ddelta_bias if delta_bias is not None else None,
|
||||
# 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
|
||||
dB_proj_bias,
|
||||
dC_proj_bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def mamba_inner_fn(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
A,
|
||||
B=None,
|
||||
C=None,
|
||||
D=None,
|
||||
delta_bias=None,
|
||||
B_proj_bias=None,
|
||||
C_proj_bias=None,
|
||||
delta_softplus=True,
|
||||
checkpoint_lvl=1,
|
||||
b_rms_weight=None,
|
||||
c_rms_weight=None,
|
||||
dt_rms_weight=None,
|
||||
b_c_dt_rms_eps=1e-6,
|
||||
):
|
||||
return MambaInnerFn.apply(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
delta_bias,
|
||||
B_proj_bias,
|
||||
C_proj_bias,
|
||||
delta_softplus,
|
||||
checkpoint_lvl,
|
||||
b_rms_weight,
|
||||
c_rms_weight,
|
||||
dt_rms_weight,
|
||||
b_c_dt_rms_eps,
|
||||
)
|
||||
@@ -23,7 +23,6 @@ from ...utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.mamba.configuration_mamba.MambaConfig with mamba->falcon_mamba,Mamba->FalconMamba,MAMBA->FALCON_MAMBA,state-spaces/falcon_mamba-2.8b->tiiuae/falcon-mamba-7b,use_falcon_mambapy->use_mambapy
|
||||
class FalconMambaConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
|
||||
@@ -82,8 +81,8 @@ class FalconMambaConfig(PretrainedConfig):
|
||||
Whether or not the cache should be used.
|
||||
use_mambapy (`bool`, *optional*, defaults to `False`):
|
||||
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not avaiable. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
|
||||
|
||||
|
||||
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
|
||||
Example:
|
||||
|
||||
```python
|
||||
@@ -127,6 +126,7 @@ class FalconMambaConfig(PretrainedConfig):
|
||||
rescale_prenorm_residual=False,
|
||||
use_cache=True,
|
||||
use_mambapy=False,
|
||||
mixer_rms_eps=1e-6,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@@ -154,5 +154,6 @@ class FalconMambaConfig(PretrainedConfig):
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.use_cache = use_cache
|
||||
self.use_mambapy = use_mambapy
|
||||
self.mixer_rms_eps = mixer_rms_eps
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 state-spaces/falcon_mamba org and HuggingFace Inc. team.
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -45,8 +45,10 @@ else:
|
||||
pscan = None
|
||||
|
||||
if is_mamba_ssm_available():
|
||||
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
|
||||
from ...kernels.falcon_mamba import mamba_inner_fn
|
||||
else:
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
||||
|
||||
@@ -131,6 +133,15 @@ class FalconMambaMixer(nn.Module):
|
||||
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
||||
self.use_bias = config.use_bias
|
||||
|
||||
# Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
|
||||
self.register_buffer(
|
||||
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
|
||||
)
|
||||
self.rms_eps = config.mixer_rms_eps
|
||||
|
||||
if not is_fast_path_available:
|
||||
if self.use_mambapy:
|
||||
if is_mambapy_available():
|
||||
@@ -175,6 +186,10 @@ class FalconMambaMixer(nn.Module):
|
||||
self.D.float(),
|
||||
delta_bias=self.dt_proj.bias.float(),
|
||||
delta_softplus=True,
|
||||
b_rms_weight=self.b_c_rms,
|
||||
c_rms_weight=self.b_c_rms,
|
||||
dt_rms_weight=self.dt_rms,
|
||||
b_c_dt_rms_eps=self.rms_eps,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -214,9 +229,9 @@ class FalconMambaMixer(nn.Module):
|
||||
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
|
||||
)
|
||||
|
||||
B = rms_forward(B)
|
||||
C = rms_forward(C)
|
||||
time_step = rms_forward(time_step)
|
||||
B = rms_forward(B, variance_epsilon=self.rms_eps)
|
||||
C = rms_forward(C, variance_epsilon=self.rms_eps)
|
||||
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
|
||||
|
||||
# In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
|
||||
# at the price of a small overhead.
|
||||
@@ -315,9 +330,9 @@ class FalconMambaMixer(nn.Module):
|
||||
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
|
||||
)
|
||||
|
||||
B = rms_forward(B)
|
||||
C = rms_forward(C)
|
||||
time_step = rms_forward(time_step)
|
||||
B = rms_forward(B, variance_epsilon=self.rms_eps)
|
||||
C = rms_forward(C, variance_epsilon=self.rms_eps)
|
||||
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
|
||||
|
||||
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
|
||||
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
|
||||
|
||||
@@ -524,3 +524,32 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_training_kernel(self):
|
||||
model_id = "tiiuae/falcon-mamba-7b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
text = "Hello today"
|
||||
|
||||
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = torch.argmax(model(**inputs).logits, dim=-1)
|
||||
|
||||
out_no_training = tokenizer.batch_decode(logits)
|
||||
|
||||
model.train()
|
||||
lm_logits = model(**inputs).logits
|
||||
next_token = torch.argmax(lm_logits, dim=-1)
|
||||
|
||||
out_training = tokenizer.batch_decode(next_token)
|
||||
|
||||
# Just verify backward works
|
||||
loss = (1 - lm_logits).mean()
|
||||
loss.backward()
|
||||
|
||||
self.assertEqual(out_training, out_no_training)
|
||||
|
||||
@@ -332,6 +332,7 @@ IGNORE_SUBMODULES = [
|
||||
"modeling_attn_mask_utils",
|
||||
"safetensors_conversion",
|
||||
"modeling_gguf_pytorch_utils",
|
||||
"kernels.falcon_mamba",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user