fix: fix CohereForAI/c4ai-command-r-plus

OlivierDehaene 2024-04-04 18:46:51 +02:00
parent 106d8ee818
commit 5088005908
22 changed files with 116 additions and 588 deletions

View File

@@ -160,11 +160,6 @@ WORKDIR /usr/src
 COPY server/Makefile-selective-scan Makefile
 RUN make build-all
-# Build megablocks
-FROM kernel-builder as megablocks-builder
-RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
 # Text Generation Inference base image
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
@@ -186,9 +181,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
         curl \
         && rm -rf /var/lib/apt/lists/*
-# Copy conda with PyTorch and Megablocks installed
-COPY --from=megablocks-builder /opt/conda /opt/conda
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

View File

@@ -17,9 +17,6 @@ gen-server:
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
-install-megablocks:
-	pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
 install: gen-server
 	pip install pip --upgrade
 	pip install -r requirements_cuda.txt

View File

@ -1,4 +1,4 @@
flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

View File

@@ -4,7 +4,7 @@ vllm-cuda:
 	git clone https://github.com/vllm-project/vllm.git vllm
 build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+	cd vllm && git fetch && git checkout b7782002e1da25de77e0b1890ff8b72dd4df917c
 	cd vllm && python setup.py build
 install-vllm-cuda: build-vllm-cuda

View File

@@ -43,7 +43,7 @@ class CacheManager:
         ]
         self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
         self.slots = torch.arange(
-            0, num_blocks * self.block_size, dtype=torch.int32
+            0, num_blocks * self.block_size, dtype=torch.int64
         ).view(num_blocks, self.block_size)
     def allocate(

View File

@@ -61,6 +61,7 @@ class CohereConfig(PretrainedConfig):
         attention_bias=False,
         attention_dropout=0.0,
         logit_scale=1.0,
+        use_qk_norm=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class CohereConfig(PretrainedConfig):
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.logit_scale = logit_scale
+        self.use_qk_norm = use_qk_norm
         super().__init__(
             pad_token_id=pad_token_id,
@@ -175,16 +177,28 @@ class FlashCohereAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
+        self.use_qk_norm = config.use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = FastRMSNorm.load(
+                prefix=f"{prefix}.q_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+            self.k_norm = FastRMSNorm.load(
+                prefix=f"{prefix}.k_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+        else:
+            self.q_norm = None
+            self.k_norm = None
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=config.attention_bias,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -199,21 +213,25 @@ class FlashCohereAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        query, kv = qkv.split(
+        query, key, value = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=1,
         )
+        if self.use_qk_norm:
+            query = self.q_norm(query)
+            key = self.k_norm(key)
         query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key = key.view(-1, self.num_key_value_heads, self.head_size)
+        value = key.view(-1, self.num_key_value_heads, self.head_size)
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, key, cos, sin)
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
         # output tensor
         attn_output = torch.empty_like(query)
@@ -223,8 +241,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             flash_attn.attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                key,
+                value,
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
@@ -235,9 +253,9 @@ class FlashCohereAttention(torch.nn.Module):
             paged_attention.attention(
                 attn_output,
                 query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
+                key,
+                value,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
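
Note on the change above: the Cohere attention now splits the fused QKV projection into separate query, key and value tensors and, when use_qk_norm is set (as for c4ai-command-r-plus), RMS-normalizes queries and keys before the rotary embedding. A minimal standalone sketch of that input path in plain PyTorch follows; the rms_norm helper, the single shared weight per projection, and the function name are illustrative assumptions, not the repository's FastRMSNorm API (the real norm weights are loaded from the checkpoint's q_norm/k_norm tensors).

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def split_qkv_with_qk_norm(qkv, num_heads, num_kv_heads, head_size, q_weight, k_weight, use_qk_norm=True):
    # Split the fused projection into query, key, value (GQA layout), then
    # optionally normalize q and k per head before rotary embeddings are applied.
    query, key, value = qkv.split(
        [head_size * num_heads, head_size * num_kv_heads, head_size * num_kv_heads],
        dim=1,
    )
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)
    if use_qk_norm:
        query = rms_norm(query, q_weight)
        key = rms_norm(key, k_weight)
    return query, key, value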

View File

@@ -16,14 +16,13 @@
 import torch
 import torch.distributed
-import numpy as np
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
+from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,
@@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
 )
 from text_generation_server.utils.log import log_once
-HAS_MEGABLOCKS = True
-try:
-    import stk
-    import megablocks.ops as ops
-except ImportError:
-    logger.warning("Dbrx: megablocks is not installed")
-    HAS_MEGABLOCKS = False
 class DbrxAttentionConfig(PretrainedConfig):
     def __init__(
@@ -384,10 +375,6 @@ class DbrxAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -443,7 +430,7 @@ class DbrxAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
@@ -531,18 +518,6 @@ def round_up(x: torch.Tensor, value: int):
 class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()
         self.moe_normalize_expert_weights = (
@@ -572,241 +547,40 @@ class BlockSparseMoE(nn.Module):
         )
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
-        self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.wv1 = torch.cat([w1, v1], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
         self.process_group = weights.process_group
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim % self.blocking == 0
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim // self.blocking
-        if self.offsets is None or block_rows > self.offsets_block_rows:
-            self.offsets = torch.arange(
-                0,
-                block_rows * blocks_per_row + 1,
-                blocks_per_row,
-                dtype=torch.int32,
-                device=x.device,
-            )
-            self.offsets_block_rows = block_rows
-            offsets = self.offsets
-        else:
-            offsets = self.offsets[: block_rows + 1]
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(
-            padded_bins, self.blocking, block_rows, blocks_per_row
-        )
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            False,
-            False,
-            False,
-        )
-    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        # selected_experts = selected_experts.int()
-        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
-        # and indices == how to sort tokens?
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-        # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
-        # indices => [14, 32, 33, ...] => [num_tokens * top_k]
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-        # tokens_per_expert => [3, 0, 2, ...] => [num_experts]
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-        # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
-        # padded_tokens_per_expert => [128, O, 128, ...]
-        # Cumulative selected experts per token
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-        # padded_bins => [128, 128, 256, ...]
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        # bins => [3, 3, 5, ...]
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        selected_experts, weights = select_experts(
-            gate_logits, self.top_k, self.moe_normalize_expert_weights
-        )
-        (
-            indices,
-            bin_ids,
-            bins,
-            padded_bins,
-            _,
-        ) = self.indices_and_padded_bins(selected_experts)
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and v1,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
-            * stk.ops.sdd(x, self.v1.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-        # Permute back and remove padding
-        # (sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-        return x.view(*input_shape)
-    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                weights,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            weights.scatter_(1, not_selected_experts, 0)
-        # Re-normalize
-        if self.moe_normalize_expert_weights:
-            weights = weights / torch.norm(
-                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
-            )
-        weights = weights.to(x.dtype)
-        # Expand to [num_experts, sequence_length, model_dim]
-        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
-        # Permute to [num_experts, model_dim, ffn_dim]
-        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-        v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
-        out = torch.bmm(
-            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
-        )
-        # Mask not selected experts
-        out *= weights.t().view(self.num_experts, -1, 1)
-        # Sum experts
-        out = out.sum(0)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
+            x,
+            self.wv1,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=self.moe_normalize_expert_weights,
+            inplace=True,
+        )
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        return out
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
+        return out.view(*x.shape)
 class DenseMoE(nn.Module):
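
The MoE changes above drop the megablocks/stk path entirely and route through vLLM's fused_moe kernel with merged expert weights (gate and up projections concatenated, down projection pre-transposed). The dense PyTorch reference below is only meant to make the expected weight layout and routing math explicit; it is not vLLM's Triton kernel, and the function and argument names are illustrative.

import torch

def moe_reference(x, w_gate_up, w_down, router_logits, top_k, renormalize=True):
    # x: (tokens, hidden); w_gate_up: (experts, 2 * ffn, hidden), i.e. cat([w1, v1/w3], dim=1);
    # w_down: (experts, hidden, ffn), i.e. w2 viewed and transposed as in the diff.
    num_experts, two_ffn, _ = w_gate_up.shape
    ffn = two_ffn // 2
    probs = router_logits.softmax(dim=-1, dtype=torch.float)
    weights, selected = probs.topk(top_k, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    weights = weights.to(x.dtype)
    out = torch.zeros_like(x)
    for e in range(num_experts):
        token_idx, k_idx = torch.where(selected == e)
        if token_idx.numel() == 0:
            continue
        gate_up = x[token_idx] @ w_gate_up[e].t()           # (n, 2 * ffn)
        gate, up = gate_up.split(ffn, dim=-1)
        hidden = torch.nn.functional.silu(gate) * up        # SwiGLU expert MLP
        out[token_idx] += weights[token_idx, k_idx, None] * (hidden @ w_down[e].t())
    return out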

View File

@@ -188,10 +188,6 @@ class FlashGemmaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -244,7 +240,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -176,10 +176,6 @@ class FlashLlamaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -232,7 +228,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -173,10 +173,6 @@ class MistralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -236,7 +232,7 @@ class MistralAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -24,6 +24,7 @@ import torch.distributed
 import numpy as np
 from torch import nn
+from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
@@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
-HAS_MEGABLOCKS = True
-try:
-    import stk
-    import megablocks.ops as ops
-except ImportError:
-    logger.warning("Mixtral: megablocks is not installed")
-    HAS_MEGABLOCKS = False
 class MixtralConfig(PretrainedConfig):
     model_type = "mixtral"
@@ -229,10 +222,6 @@ class MixtralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -292,7 +281,7 @@ class MixtralAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
@@ -321,18 +310,6 @@ def round_up(x: torch.Tensor, value: int):
 class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
     def __init__(self, prefix, config: MixtralConfig, weights):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -357,236 +334,40 @@ class BlockSparseMoE(nn.Module):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.w13 = torch.cat([w1, w3], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts", "w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
         self.process_group = weights.process_group
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim % self.blocking == 0
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim // self.blocking
-        if self.offsets is None or block_rows > self.offsets_block_rows:
-            self.offsets = torch.arange(
-                0,
-                block_rows * blocks_per_row + 1,
-                blocks_per_row,
-                dtype=torch.int32,
-                device=x.device,
-            )
-            self.offsets_block_rows = block_rows
-            offsets = self.offsets
-        else:
-            offsets = self.offsets[: block_rows + 1]
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(
-            padded_bins, self.blocking, block_rows, blocks_per_row
-        )
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            False,
-            False,
-            False,
-        )
-    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        # selected_experts = selected_experts.int()
-        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
-        # and indices == how to sort tokens?
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-        # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
-        # indices => [14, 32, 33, ...] => [num_tokens * top_k]
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-        # tokens_per_expert => [3, 0, 2, ...] => [num_experts]
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-        # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
-        # padded_tokens_per_expert => [128, O, 128, ...]
-        # Cumulative selected experts per token
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-        # padded_bins => [128, 128, 256, ...]
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        # bins => [3, 3, 5, ...]
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        selected_experts, weights = select_experts(gate_logits, self.top_k)
-        (
-            indices,
-            bin_ids,
-            bins,
-            padded_bins,
-            _,
-        ) = self.indices_and_padded_bins(selected_experts)
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and w3,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
-            * stk.ops.sdd(x, self.w3.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-        # Permute back and remove padding
-        # (sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-        return x.view(*input_shape)
-    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-        # Expand to [num_experts, sequence_length, model_dim]
-        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
-        # Permute to [num_experts, model_dim, ffn_dim]
-        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
-        out = torch.bmm(
-            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
-        )
-        # Mask not selected experts
-        out *= weights.t().view(self.num_experts, -1, 1)
-        # Sum experts
-        out = out.sum(0)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
+            x,
+            self.w13,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        return out
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
+        return out.view(*x.shape)
 class DenseMoE(nn.Module):

View File

@@ -120,9 +120,6 @@ class FlashNeoxAttention(torch.nn.Module):
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
-        self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
     def forward(
         self,
@@ -168,7 +165,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 0],
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -140,10 +140,6 @@ class FlashPhiAttention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -206,7 +202,7 @@ class FlashPhiAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -104,10 +104,6 @@ class Qwen2Attention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -167,7 +163,7 @@ class Qwen2Attention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -151,15 +151,6 @@ class FlashRWAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
-        if self.num_heads_kv == 1:
-            self.kv_head_mapping = torch.zeros(
-                self.num_heads, dtype=torch.int32, device=weights.device
-            )
-        else:
-            self.kv_head_mapping = torch.arange(
-                0, self.num_heads, dtype=torch.int32, device=weights.device
-            )
     def forward(
         self,
         hidden_states,
@@ -213,7 +204,7 @@ class FlashRWAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_heads_kv,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
@@ -272,10 +263,6 @@ class FlashRWLargeAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
-        self.kv_head_mapping = torch.arange(
-            0, self.num_groups, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_heads)
     def forward(
         self,
         hidden_states,
@@ -332,7 +319,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_groups,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -241,9 +241,6 @@ class FlashMQAttention(torch.nn.Module):
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
         )
-        self.kv_head_mapping = torch.zeros(
-            self.num_heads, dtype=torch.int32, device=weights.device
-        )
     def forward(
         self,
@@ -292,7 +289,7 @@ class FlashMQAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -190,9 +190,6 @@ class Starcoder2Attention(torch.nn.Module):
             bias=config.use_bias,
         )
         self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
     def forward(
         self,
@@ -252,7 +249,7 @@ class Starcoder2Attention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -690,7 +690,7 @@ class FlashCausalLM(Model):
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)

View File

@@ -378,7 +378,7 @@ class BaseFlashMistral(FlashCausalLM):
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)

View File

@@ -88,6 +88,9 @@ def attention(
             out,
             cu_seqlens,
             cu_seqlens,
+            None,
+            None,
+            None,
             max_s,
             max_s,
             0.0,

View File

@@ -174,6 +174,8 @@ class EETQLinear(nn.Module):
     ) -> None:
         super().__init__()
         device = weight.device
+        if weight.dtype != torch.float16:
+            weight = weight.to(dtype=torch.float16)
         weight = torch.t(weight).contiguous().cpu()
         weight, scale = quant_weights(weight, torch.int8, False)

View File

@ -1,8 +1,7 @@
import torch import torch
# vllm imports # vllm imports
from vllm import cache_ops from vllm._C import cache_ops, ops
from vllm import attention_ops
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
@ -14,7 +13,7 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
): ):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def attention( def attention(
@ -22,7 +21,7 @@ def attention(
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor, num_key_value_heads: torch.Tensor,
softmax_scale: float, softmax_scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
@ -54,20 +53,22 @@ def attention(
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1: if use_v1:
attention_ops.paged_attention_v1( ops.paged_attention_v1(
out, out,
query, query,
key_cache, key_cache,
value_cache, value_cache,
kv_head_mapping, num_key_value_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, block_size,
max_s, max_s,
None, None,
"auto",
1.0,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
@ -83,7 +84,7 @@ def attention(
device=out.device, device=out.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2( ops.paged_attention_v2(
out, out,
exp_sums, exp_sums,
max_logits, max_logits,
@ -91,11 +92,13 @@ def attention(
query, query,
key_cache, key_cache,
value_cache, value_cache,
kv_head_mapping, num_key_value_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, block_size,
max_s, max_s,
None, None,
"auto",
1.0,
) )
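
For context on this last file: the wrapper now calls the paged-attention kernels from vllm._C directly, passes the number of key/value heads instead of a per-head mapping tensor, appends the kv-cache dtype ("auto") and scale (1.0) arguments required by the newer vLLM kernels, and only selects the single-pass V1 kernel for maximum sequence lengths up to 8192. A small sketch of that dispatch rule is shown below; the ceil-division used for max_num_partitions is an assumption based on the _PARTITION_SIZE constant above and is not part of this hunk.

_PARTITION_SIZE = 512

def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # Prefer V1 (no reduction pass) unless sequences are long enough that the
    # partitioned V2 kernel is worthwhile.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)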