Fix: Fix FalconMamba training issues due to incompatible kernels (#33195)

* fix FM training kernels

* fix copies

* fix copies

* propagate to slow path

* make it BC

* add comment

* fix test
This commit is contained in:
Younes Belkada
2024-09-05 13:55:08 +04:00
committed by GitHub
parent 43df47d8e7
commit 47b096412d
6 changed files with 597 additions and 11 deletions

View File

@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .selective_scan_with_ln_interface import mamba_inner_fn

View File

@@ -0,0 +1,525 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import causal_conv1d_cuda
except ImportError:
causal_conv1d_cuda = None
import mamba_ssm
import selective_scan_cuda
# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
if hasattr(mamba_ssm.ops.triton, "layernorm"):
from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
else:
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
class SelectiveScanFn(torch.autograd.Function):
@staticmethod
def forward(
ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = rearrange(B, "b dstate l -> b 1 dstate l")
ctx.squeeze_B = True
if C.dim() == 3:
C = rearrange(C, "b dstate l -> b 1 dstate l")
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
ctx.delta_softplus = delta_softplus
ctx.has_z = z is not None
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if not ctx.has_z:
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out if not return_last_state else (out, last_state)
else:
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
@staticmethod
def backward(ctx, dout, *args):
if not ctx.has_z:
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
z = None
out = None
else:
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
# Here we just pass in None and dz will be allocated in the C++ code.
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
dout,
x,
out,
None,
ctx.delta_softplus,
False, # option to recompute out_z, not used here
)
dz = rest[0] if ctx.has_z else None
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (
du,
ddelta,
dA,
dB,
dC,
dD if D is not None else None,
dz,
ddelta_bias if delta_bias is not None else None,
None,
None,
)
def rms_norm_forward(
x,
weight,
bias,
eps=1e-6,
is_rms_norm=True,
):
# x (b l) d
if x.stride(-1) != 1:
x = x.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
# y (b l) d
return y
def selective_scan_fn(
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.
"""
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
def selective_scan_ref(
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
"""
u: r(B D L)
delta: r(B D L)
A: c(D N) or r(D N)
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
if not is_variable_B:
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum("bdn,dn->bd", x, C)
else:
if C.dim() == 3:
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
else:
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
class MambaInnerFn(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx,
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
checkpoint_lvl=1,
b_rms_weight=None,
c_rms_weight=None,
dt_rms_weight=None,
b_c_dt_rms_eps=1e-6,
):
"""
xz: (batch, dim, seqlen)
"""
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
assert checkpoint_lvl in [0, 1]
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
if torch.is_autocast_enabled():
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
out_proj_bias = (
out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
)
if xz.stride(-1) != 1:
xz = xz.contiguous()
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
ctx.is_variable_B = B is None
ctx.is_variable_C = C is None
ctx.B_proj_bias_is_None = B_proj_bias is None
ctx.C_proj_bias_is_None = C_proj_bias is None
if B is None: # variable B
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if B.stride(-1) != 1:
B = B.contiguous()
if C is None: # variable C
C = x_dbl[:, -d_state:] # (bl dstate)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if C.stride(-1) != 1:
C = C.contiguous()
if D is not None:
D = D.contiguous()
if b_rms_weight is not None:
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if dt_rms_weight is not None:
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
ctx.delta_softplus = delta_softplus
ctx.out_proj_bias_is_None = out_proj_bias is None
ctx.checkpoint_lvl = checkpoint_lvl
ctx.b_rms_weight = b_rms_weight
ctx.c_rms_weight = c_rms_weight
ctx.dt_rms_weight = dt_rms_weight
ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
conv1d_out, delta = None, None
ctx.save_for_backward(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
out,
)
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
@staticmethod
@custom_bwd
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
out,
) = ctx.saved_tensors
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
x, z = xz.chunk(2, dim=1)
if dout.stride(-1) != 1:
dout = dout.contiguous()
if ctx.checkpoint_lvl == 1:
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
if dt_rms_weight is not None:
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
if b_rms_weight is not None:
# Recompute & RMSNorm B
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
# Recompute & RMSNorm C
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
dx, dz = dxz.chunk(2, dim=1)
dout = rearrange(dout, "b l e -> e (b l)")
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
conv1d_out,
delta,
A,
B,
C,
D,
z,
delta_bias,
dout_y,
scan_intermediates,
out,
dz,
ctx.delta_softplus,
True, # option to recompute out_z
)
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
dD = dD if D is not None else None
dx_dbl = torch.empty_like(x_dbl)
dB_proj_bias = None
if ctx.is_variable_B:
if not A.is_complex():
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
dB = None
dC_proj_bias = None
if ctx.is_variable_C:
if not A.is_complex():
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
dx_dbl[:, -d_state:] = dC # (bl d)
dC = None
ddelta = rearrange(ddelta, "b d l -> d (b l)")
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
)
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
return (
dxz,
dconv1d_weight,
dconv1d_bias,
dx_proj_weight,
ddelta_proj_weight,
dout_proj_weight,
dout_proj_bias,
dA,
dB,
dC,
dD,
ddelta_bias if delta_bias is not None else None,
# 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
dB_proj_bias,
dC_proj_bias,
None,
None,
None,
None,
None,
None,
)
def mamba_inner_fn(
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
checkpoint_lvl=1,
b_rms_weight=None,
c_rms_weight=None,
dt_rms_weight=None,
b_c_dt_rms_eps=1e-6,
):
return MambaInnerFn.apply(
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B,
C,
D,
delta_bias,
B_proj_bias,
C_proj_bias,
delta_softplus,
checkpoint_lvl,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
b_c_dt_rms_eps,
)

View File

@@ -23,7 +23,6 @@ from ...utils import logging
logger = logging.get_logger(__name__)
# Copied from transformers.models.mamba.configuration_mamba.MambaConfig with mamba->falcon_mamba,Mamba->FalconMamba,MAMBA->FALCON_MAMBA,state-spaces/falcon_mamba-2.8b->tiiuae/falcon-mamba-7b,use_falcon_mambapy->use_mambapy
class FalconMambaConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
@@ -82,8 +81,8 @@ class FalconMambaConfig(PretrainedConfig):
Whether or not the cache should be used.
use_mambapy (`bool`, *optional*, defaults to `False`):
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not avaiable. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
Example:
```python
@@ -127,6 +126,7 @@ class FalconMambaConfig(PretrainedConfig):
rescale_prenorm_residual=False,
use_cache=True,
use_mambapy=False,
mixer_rms_eps=1e-6,
**kwargs,
):
self.vocab_size = vocab_size
@@ -154,5 +154,6 @@ class FalconMambaConfig(PretrainedConfig):
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.use_mambapy = use_mambapy
self.mixer_rms_eps = mixer_rms_eps
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 state-spaces/falcon_mamba org and HuggingFace Inc. team.
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,8 +45,10 @@ else:
pscan = None
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from ...kernels.falcon_mamba import mamba_inner_fn
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
@@ -131,6 +133,15 @@ class FalconMambaMixer(nn.Module):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
# Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
self.register_buffer(
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
)
self.register_buffer(
"dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
)
self.rms_eps = config.mixer_rms_eps
if not is_fast_path_available:
if self.use_mambapy:
if is_mambapy_available():
@@ -175,6 +186,10 @@ class FalconMambaMixer(nn.Module):
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
b_rms_weight=self.b_c_rms,
c_rms_weight=self.b_c_rms,
dt_rms_weight=self.dt_rms,
b_c_dt_rms_eps=self.rms_eps,
)
else:
@@ -214,9 +229,9 @@ class FalconMambaMixer(nn.Module):
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
B = rms_forward(B)
C = rms_forward(C)
time_step = rms_forward(time_step)
B = rms_forward(B, variance_epsilon=self.rms_eps)
C = rms_forward(C, variance_epsilon=self.rms_eps)
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
# In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
# at the price of a small overhead.
@@ -315,9 +330,9 @@ class FalconMambaMixer(nn.Module):
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
B = rms_forward(B)
C = rms_forward(C)
time_step = rms_forward(time_step)
B = rms_forward(B, variance_epsilon=self.rms_eps)
C = rms_forward(C, variance_epsilon=self.rms_eps)
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(

View File

@@ -524,3 +524,32 @@ class FalconMambaIntegrationTests(unittest.TestCase):
out = tok.batch_decode(out, skip_special_tokens=True)
self.assertListEqual(out, EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_training_kernel(self):
model_id = "tiiuae/falcon-mamba-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer.pad_token_id = tokenizer.eos_token_id
text = "Hello today"
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
with torch.no_grad():
logits = torch.argmax(model(**inputs).logits, dim=-1)
out_no_training = tokenizer.batch_decode(logits)
model.train()
lm_logits = model(**inputs).logits
next_token = torch.argmax(lm_logits, dim=-1)
out_training = tokenizer.batch_decode(next_token)
# Just verify backward works
loss = (1 - lm_logits).mean()
loss.backward()
self.assertEqual(out_training, out_no_training)

View File

@@ -332,6 +332,7 @@ IGNORE_SUBMODULES = [
"modeling_attn_mask_utils",
"safetensors_conversion",
"modeling_gguf_pytorch_utils",
"kernels.falcon_mamba",
]