Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
fix: fix CohereForAI/c4ai-command-r-plus
This commit is contained in:
parent 106d8ee818
commit 5088005908
@@ -160,11 +160,6 @@ WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all

# Build megablocks
FROM kernel-builder as megablocks-builder

RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base

@@ -186,9 +181,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
&& rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

@@ -17,9 +17,6 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py

install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt

@@ -1,4 +1,4 @@
flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

@@ -4,7 +4,7 @@ vllm-cuda:
git clone https://github.com/vllm-project/vllm.git vllm

build-vllm-cuda: vllm-cuda
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
cd vllm && git fetch && git checkout b7782002e1da25de77e0b1890ff8b72dd4df917c
cd vllm && python setup.py build

install-vllm-cuda: build-vllm-cuda
@@ -43,7 +43,7 @@ class CacheManager:
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32
0, num_blocks * self.block_size, dtype=torch.int64
).view(num_blocks, self.block_size)

def allocate(
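A side note on the CacheManager hunk above: the slot indices move from int32 to int64, presumably because the updated vLLM cache kernels index the KV cache with 64-bit slot mappings. A minimal sketch of the resulting layout (num_blocks and block_size here are hypothetical, not taken from the diff):

import torch

num_blocks, block_size = 4, 16  # hypothetical sizes for illustration
# One int64 slot id per cache position, grouped per block, as in the hunk above.
slots = torch.arange(0, num_blocks * block_size, dtype=torch.int64).view(
    num_blocks, block_size
)
assert slots.dtype == torch.int64 and slots.shape == (num_blocks, block_size)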
@@ -61,6 +61,7 @@ class CohereConfig(PretrainedConfig):
attention_bias=False,
attention_dropout=0.0,
logit_scale=1.0,
use_qk_norm=False,
**kwargs,
):
self.vocab_size = vocab_size

@@ -84,6 +85,7 @@ class CohereConfig(PretrainedConfig):
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.logit_scale = logit_scale
self.use_qk_norm = use_qk_norm

super().__init__(
pad_token_id=pad_token_id,

@@ -175,16 +177,28 @@ class FlashCohereAttention(torch.nn.Module):

self.query_key_value = load_attention(config, prefix, weights)

self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.k_norm = FastRMSNorm.load(
prefix=f"{prefix}.k_norm",
weights=weights,
eps=config.layer_norm_eps,
)
else:
self.q_norm = None
self.k_norm = None

self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=config.attention_bias,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,
@@ -199,21 +213,25 @@ class FlashCohereAttention(torch.nn.Module):
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
query, key, value = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
],
dim=1,
)
if self.use_qk_norm:
query = self.q_norm(query)
key = self.k_norm(key)

query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size)
value = value.view(-1, self.num_key_value_heads, self.head_size)

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
self.rotary_emb(query, key, cos, sin)

paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

# output tensor
attn_output = torch.empty_like(query)
@@ -223,8 +241,8 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,

@@ -235,9 +253,9 @@ class FlashCohereAttention(torch.nn.Module):
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
key,
value,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,
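Taken together, the FlashCohereAttention hunks above are the core of the command-r-plus fix: the fused QKV projection is now split into separate query/key/value tensors instead of a query plus packed kv, and query/key are optionally RMS-normalized when use_qk_norm is set, before rotary embedding and caching. A self-contained sketch of that flow, with plain callables standing in for FastRMSNorm and with the TGI kernels left out (shapes only, not the actual implementation):

import torch

def split_qkv(qkv, num_heads, num_kv_heads, head_size, q_norm=None, k_norm=None):
    # Three-way split replaces the old (query, packed kv) split.
    query, key, value = qkv.split(
        [head_size * num_heads, head_size * num_kv_heads, head_size * num_kv_heads],
        dim=1,
    )
    # Optional query/key normalization, gated by use_qk_norm (new for command-r-plus).
    if q_norm is not None:
        query = q_norm(query)
    if k_norm is not None:
        key = k_norm(key)
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)
    return query, key, value

# Example with toy sizes: 2 tokens, 4 query heads, 2 kv heads, head_size 8.
qkv = torch.randn(2, 8 * (4 + 2 + 2))
q, k, v = split_qkv(qkv, num_heads=4, num_kv_heads=2, head_size=8)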
@@ -16,14 +16,13 @@
import torch
import torch.distributed

import numpy as np

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger

from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
FastLinear,

@@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
)
from text_generation_server.utils.log import log_once

HAS_MEGABLOCKS = True
try:
import stk
import megablocks.ops as ops
except ImportError:
logger.warning("Dbrx: megablocks is not installed")
HAS_MEGABLOCKS = False


class DbrxAttentionConfig(PretrainedConfig):
def __init__(

@@ -384,10 +375,6 @@ class DbrxAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -443,7 +430,7 @@ class DbrxAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -531,18 +518,6 @@ def round_up(x: torch.Tensor, value: int):


class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accomodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""

def __init__(self, prefix, config: DbrxConfig, weights):
super().__init__()
self.moe_normalize_expert_weights = (
@@ -572,241 +547,40 @@ class BlockSparseMoE(nn.Module):
)

# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)

self.offsets = None
self.offsets_block_rows = 0
w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.wv1 = torch.cat([w1, v1], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts.mlp.w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)

self.process_group = weights.process_group

# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1

def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0

# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[: block_rows + 1]

# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)

# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)

def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()

# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]

# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]

# Round the token counts up to the block size used in
# the matrix muliplications. Caculate the starting
# position of each bin.

# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]

# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]

# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]

return indices, bin_ids, bins, padded_bins, tokens_per_expert

def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])

# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(
gate_logits, self.top_k, self.moe_normalize_expert_weights
)

(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)

# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)

# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)

# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and v1,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.v1.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)

# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)

# Permute back and remove padding
# (sequence_length, model_dim)
x = ops.padded_scatter(
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.wv1,
self.w2,
router_logits,
self.top_k,
self.quantize_scatter_num_bits,
).view(*input_shape)

if self.process_group.size() > 1:
torch.distributed.all_reduce(x, group=self.process_group)

return x.view(*input_shape)

def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])

# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
weights,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
weights.scatter_(1, not_selected_experts, 0)

# Re-normalize
if self.moe_normalize_expert_weights:
weights = weights / torch.norm(
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
)
weights = weights.to(x.dtype)

# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])

# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)

inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)

out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)

# Sum experts
out = out.sum(0)

# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)

return out

def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x) > 256 and HAS_MEGABLOCKS:
return self.sparse_forward(x)
# This is faster when there is not a lot of tokens
return self.dense_forward(x)
return out.view(*x.shape)


class DenseMoE(nn.Module):
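The Dbrx hunks above drop the megablocks/stk sparse and dense MoE paths in favor of a single vLLM fused_moe call over the packed wv1/w2 expert weights. As a reference for what that kernel computes, here is a hedged pure-PyTorch sketch of top-k routing with the gated (w1/v1) expert MLP. It matches the weight layout built in __init__ above but is not the fused kernel itself, and the silu activation is an assumption standing in for the model's configured self.act:

import torch

def moe_reference(x, wv1, w2, router_logits, top_k, renormalize=True):
    # x: (tokens, hidden); wv1: (experts, 2*ffn, hidden); w2: (experts, hidden, ffn)
    num_experts, two_ffn, _ = wv1.shape
    ffn = two_ffn // 2
    probs = torch.softmax(router_logits.float(), dim=-1)
    weights, experts = probs.topk(top_k, dim=-1)  # (tokens, top_k)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for e in range(num_experts):
        token_idx, k_idx = (experts == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx]                      # (n, hidden) tokens routed to expert e
        w1, v1 = wv1[e, :ffn], wv1[e, ffn:]   # gate and up projections
        inter = torch.nn.functional.silu(h @ w1.t()) * (h @ v1.t())  # silu assumed
        out[token_idx] += weights[token_idx, k_idx, None].to(x.dtype) * (inter @ w2[e].t())
    return out

# Toy usage: 5 tokens, 4 experts, ffn 6, hidden 8, top-2 routing.
x = torch.randn(5, 8)
wv1 = torch.randn(4, 12, 8)
w2 = torch.randn(4, 8, 6)
y = moe_reference(x, wv1, w2, torch.randn(5, 4), top_k=2)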
@@ -188,10 +188,6 @@ class FlashGemmaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -244,7 +240,7 @@ class FlashGemmaAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -176,10 +176,6 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -232,7 +228,7 @@ class FlashLlamaAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -173,10 +173,6 @@ class MistralAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -236,7 +232,7 @@ class MistralAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,
@@ -24,6 +24,7 @@ import torch.distributed
import numpy as np

from torch import nn
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

@@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
get_linear,
)

HAS_MEGABLOCKS = True
try:
import stk
import megablocks.ops as ops
except ImportError:
logger.warning("Mixtral: megablocks is not installed")
HAS_MEGABLOCKS = False


class MixtralConfig(PretrainedConfig):
model_type = "mixtral"

@@ -229,10 +222,6 @@ class MixtralAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -292,7 +281,7 @@ class MixtralAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -321,18 +310,6 @@ def round_up(x: torch.Tensor, value: int):


class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accomodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""

def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
@@ -357,236 +334,40 @@ class BlockSparseMoE(nn.Module):
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)

self.offsets = None
self.offsets_block_rows = 0
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.w13 = torch.cat([w1, w3], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts", "w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)

self.process_group = weights.process_group

# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1

def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0

# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[: block_rows + 1]

# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)

# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)

def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()

# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]

# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]

# Round the token counts up to the block size used in
# the matrix muliplications. Caculate the starting
# position of each bin.

# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]

# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]

# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]

return indices, bin_ids, bins, padded_bins, tokens_per_expert

def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])

# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(gate_logits, self.top_k)

(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)

# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)

# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)

# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)

# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)

# Permute back and remove padding
# (sequence_length, model_dim)
x = ops.padded_scatter(
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.w13,
self.w2,
router_logits,
self.top_k,
self.quantize_scatter_num_bits,
).view(*input_shape)

if self.process_group.size() > 1:
torch.distributed.all_reduce(x, group=self.process_group)

return x.view(*input_shape)

def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])

# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)

# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
weights = weights.to(x.dtype)

# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])

# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
renormalize=True,
inplace=True,
)
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)

inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)

out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)

# Sum experts
out = out.sum(0)

# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)

return out

def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x) > 256 and HAS_MEGABLOCKS:
return self.sparse_forward(x)
# This is faster when there is not a lot of tokens
return self.dense_forward(x)
return out.view(*x.shape)


class DenseMoE(nn.Module):
@@ -120,9 +120,6 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,

@@ -168,7 +165,7 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -140,10 +140,6 @@ class FlashPhiAttention(torch.nn.Module):
weights=weights,
bias=True,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -206,7 +202,7 @@ class FlashPhiAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -104,10 +104,6 @@ class Qwen2Attention(torch.nn.Module):
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -167,7 +163,7 @@ class Qwen2Attention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -151,15 +151,6 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)

if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,
hidden_states,

@@ -213,7 +204,7 @@ class FlashRWAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_heads_kv,
self.softmax_scale,
block_tables,
input_lengths,

@@ -272,10 +263,6 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)

self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)

def forward(
self,
hidden_states,

@@ -332,7 +319,7 @@ class FlashRWLargeAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_groups,
self.softmax_scale,
block_tables,
input_lengths,

@@ -241,9 +241,6 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,

@@ -292,7 +289,7 @@ class FlashMQAttention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_heads,
self.softmax_scale,
block_tables,
input_lengths,

@@ -190,9 +190,6 @@ class Starcoder2Attention(torch.nn.Module):
bias=config.use_bias,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

def forward(
self,

@@ -252,7 +249,7 @@ class Starcoder2Attention(torch.nn.Module):
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.num_key_value_heads,
self.softmax_scale,
block_tables,
input_lengths,
@@ -690,7 +690,7 @@ class FlashCausalLM(Model):
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)

@@ -378,7 +378,7 @@ class BaseFlashMistral(FlashCausalLM):
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
@@ -88,6 +88,9 @@ def attention(
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,

@@ -174,6 +174,8 @@ class EETQLinear(nn.Module):
) -> None:
super().__init__()
device = weight.device
if weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16)
weight = torch.t(weight).contiguous().cpu()
weight, scale = quant_weights(weight, torch.int8, False)
@@ -1,8 +1,7 @@
import torch

# vllm imports
from vllm import cache_ops
from vllm import attention_ops
from vllm._C import cache_ops, ops

_PARTITION_SIZE = 512

@@ -14,7 +13,7 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def attention(

@@ -22,7 +21,7 @@ def attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
num_key_value_heads: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,

@@ -54,20 +53,22 @@ def attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
attention_ops.paged_attention_v1(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
num_key_value_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.

@@ -83,7 +84,7 @@ def attention(
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2(
ops.paged_attention_v2(
out,
exp_sums,
max_logits,

@@ -91,11 +92,13 @@ def attention(
query,
key_cache,
value_cache,
kv_head_mapping,
num_key_value_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
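The final hunks move the kernels from the old vllm.cache_ops/attention_ops modules to vllm._C, pass the new kv-cache dtype ("auto") and scale (1.0) arguments through reshape_and_cache and both paged attention entry points, and add a max_s <= 8192 guard to the V1/V2 dispatch. A minimal sketch of just that dispatch decision, with the kernel calls themselves left out:

def use_paged_attention_v1(max_s, max_num_partitions, num_seqs, num_heads):
    # V1 skips the split-KV reduction; prefer it only for shorter contexts and
    # when there is a single partition or already enough parallel work.
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)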