Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
feat: avoid triton selective_state_update

parent 0f124cbc52
commit a4f1916a56
@@ -1,8 +1,7 @@
 import torch
 import torch.distributed
 
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
 from typing import Optional, Tuple, Any
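The removed import is the Triton decode kernel this commit replaces. For orientation, here is a minimal sketch of what selective_state_update computes, loosely following the reference implementation that ships with mamba_ssm; the function body below is an illustration with assumed shapes, not the kernel's actual code:

import torch
import torch.nn.functional as F

def selective_state_update_ref(state, x, dt, A, B, C, D, z, dt_bias):
    # state: (b, d, n)   x, z, dt, dt_bias broadcast over (b, d)
    # A: (d, n)          B, C: (b, n)          D: (d,)
    dt = F.softplus(dt + dt_bias)                    # dt_softplus=True path
    dA = torch.exp(dt[..., None] * A)                # discrete state decay
    dB = dt[..., None] * B[:, None, :]               # discrete input matrix
    state.copy_(state * dA + dB * x[..., None])      # recurrence, updated in place
    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
    out = out + x * D                                # skip connection
    return out * F.silu(z)                           # SiLU gating

The new step() below reproduces exactly this math with plain reshape, exp, and matmul ops.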
@@ -60,6 +59,7 @@ class MambaBlock(nn.Module):
         self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
         self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
         self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
+        self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
         self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
         self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
         self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
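dt_proj_no_bias loads the same dt_proj weight a second time, just without the bias, presumably to keep the projection on the FastLinear path while letting step() apply the bias inside the softplus. A quick equivalence check, with hypothetical sizes standing in for dt_rank and d_inner:

import torch
import torch.nn.functional as F

W, b = torch.randn(1536, 48), torch.randn(1536)   # stand-ins for dt_proj.weight / .bias
dt = torch.randn(1, 48)

fused = F.softplus(F.linear(dt, W, b))            # single projection with bias baked in
deferred = F.softplus(F.linear(dt, W) + b)        # bias-free matmul, bias added before softplus
torch.testing.assert_close(fused, deferred)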
@@ -116,25 +116,32 @@ class MambaBlock(nn.Module):
 
     def step(self, hidden_states, conv_state, ssm_state):
         # only support decoding with 1 token at a time
-        xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
+        xz = self.in_proj(hidden_states.view((1, -1)))
         x, z = xz.chunk(2, dim=-1) # (B D)
         x = causal_conv1d_update(
             x,
             conv_state,
-            self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
             self.conv1d.bias,
             self.activation,
         )
-        x_db = self.x_proj(x) # (B dt_rank+2*d_state)
+        x_db = self.x_proj(x)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-        # Don't add dt_bias here
-        dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
-        y = selective_state_update(
-            ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
-        )
-        out = self.out_proj(y)
-        # conv and ssm are updated in place but we return them to make the control flow more explicit
-        return out.unsqueeze(1), conv_state, ssm_state
+        dt = self.dt_proj_no_bias(dt)
+        dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
+        dA = torch.exp(dt * self.negA)
+        dB = dt * B.view(-1, B.size(-1))
+        x_shape = (-1, x.size(-1), 1)
+        ssm_state = (ssm_state * dA + dB * x.view(x_shape))
+        c_shape = (C.size(0), C.size(1), -1)
+        out_mm_shape = (C.size(0), -1)
+        out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
+        # in-place ops
+        out.add_((x * self.D).to(out.dtype))
+        out.mul_(F.silu(z))
+        out = self.out_proj(out)
+        return out.view((1, -1, out.size(-1))), conv_state, ssm_state
 
 
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
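The decode path above still calls causal_conv1d_update from the causal-conv1d package for the rolling depthwise convolution. As a reference point, here is a pure-PyTorch sketch of that update, along the lines of the package's own reference implementation; the helper name and exact shapes are assumptions:

import torch
import torch.nn.functional as F

def conv1d_update_ref(x, conv_state, weight, bias):
    # x: (batch, dim)   conv_state: (batch, dim, width)   weight: (dim, width)
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # slide the window left
    conv_state[:, :, -1] = x                                      # append the new token
    out = torch.einsum("bdw,dw->bd", conv_state, weight) + bias   # depthwise causal conv
    return F.silu(out)                                            # activation == "silu"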
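Finally, a quick numerical check, with batch size 1 and made-up dimensions, that the reshape-and-matmul formulation introduced in step() matches the einsum semantics of the removed kernel sketched earlier; every name below is local to the snippet:

import torch
import torch.nn.functional as F

d, n = 1536, 16                                  # stand-ins for d_inner, d_state
state = torch.randn(1, d, n)
x, z = torch.randn(1, d), torch.randn(1, d)
B, C = torch.randn(1, n), torch.randn(1, n)
negA = -torch.exp(torch.randn(d, n))             # mirrors -exp(A_log)
D = torch.randn(d)
dt = F.softplus(torch.randn(1, d))               # dt after projection, bias, softplus

# Einsum formulation (semantics of the removed Triton kernel).
dA_ref = torch.exp(dt[..., None] * negA)
dB_ref = dt[..., None] * B[:, None, :]
state_ref = state * dA_ref + dB_ref * x[..., None]
out_ref = torch.einsum("bdn,bn->bd", state_ref, C) + x * D
out_ref = out_ref * F.silu(z)

# Reshape-and-matmul formulation, as written in the new step().
dt2 = dt.view(d, 1)                              # column vector, like view((dt.size(1), -1))
state_new = state * torch.exp(dt2 * negA) + (dt2 * B) * x.view(-1, d, 1)
out_new = torch.matmul(state_new, C.view(1, n, 1)).view(1, -1)
out_new.add_(x * D)
out_new.mul_(F.silu(z))

torch.testing.assert_close(state_ref, state_new)
torch.testing.assert_close(out_ref, out_new)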