Mla deepseek (#2)

* MLA optimization

* HPU needs padding in first-token generation (see the note below)
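MLA (multi-head latent attention) means the KV cache keeps only the compressed latent (kv_lora_rank wide) plus the rotary part of the key (qk_rope_head_dim wide) instead of full per-head keys and values. Because the no-PE part of the attention score, q · (W_UK c), equals (W_UK^T q) · c, the up-projections W_UK and W_UV can be absorbed into the query and output paths at decode time, so decode attends directly over the cached latents (see the identity check sketched further down). The padding note refers to the router queue: an HPU prefill batch is padded so every sequence occupies as many token slots as the longest one, and the queue must budget for that padded size.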


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-05-13 22:42:46 +08:00 committed by GitHub
parent f728cf69f2
commit b2bd163d19
8 changed files with 274 additions and 65 deletions

View File

@@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.lora import (
LoraLinear,
@@ -27,6 +28,7 @@ __all__ = [
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"Fp8Linear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",

View File

@@ -10,18 +10,21 @@ from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
paged_attention_mla,
)
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales
from .kv_cache import KVCache, get_kv_scales, KVCompressCache
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",

View File

@@ -117,7 +117,7 @@ def paged_attention(
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
@@ -138,8 +138,39 @@ def paged_attention(
return output.view(batch_size, head_num, head_size)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
]
def paged_attention_mla(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
kv_lora_rank: int = 0,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa_mla(
query=query,
key_cache=kv_cache.key,
value_cache=None,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=None,
kv_lora_rank=kv_lora_rank,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, -1)
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
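Note that paged_attention_mla stays in the latent layout end to end: the decode query arrives as [batch, heads, kv_lora_rank + qk_rope_head_dim], the single compressed cache is passed as key_cache with no value cache at all (value_cache and values_fetch_func are None), and the output comes back latent-wide ([batch, heads, kv_lora_rank]), which the caller expands through W_UV afterwards.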

View File

@@ -108,6 +108,69 @@ class KVCache:
)
class KVCompressCache(KVCache):
"""
Compressed latent key-value cache for MLA attention layers (key and value share one tensor).
"""
kv_cache: torch.Tensor
def __init__(
self,
*,
num_blocks: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = torch.zeros(
(num_blocks, BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache.dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache
@property
def value(self):
"""Get the value cache."""
return self.kv_cache
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
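KVCompressCache keeps a single storage of shape (num_blocks, BLOCK_SIZE, 1, kv_lora_rank + qk_rope_head_dim); key and value are just two views of the same tensor, and store() only ever writes the latent. A rough pure-PyTorch sketch of the slot bookkeeping with toy sizes (insert_or_update_cache is the Habana extension op; the indexed assignment below only illustrates the block/offset math):
import torch
BLOCK_SIZE, num_blocks, latent_dim = 4, 8, 16        # toy sizes, not the real config
cache = torch.zeros(num_blocks, BLOCK_SIZE, 1, latent_dim)
latent = torch.randn(3, 1, latent_dim)               # three new tokens, one latent "head" each
slots = torch.tensor([5, 6, 13])                     # flat slot ids assigned by the scheduler
block_idx = slots // BLOCK_SIZE                      # which block each token lands in
block_offset = slots % BLOCK_SIZE                    # position inside that block
cache[block_idx, block_offset] = latent              # what insert_or_update_cache does on HPU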
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,

View File

@@ -28,11 +28,12 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
Fp8Linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
paged_attention_mla,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -42,6 +43,18 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_ms
from text_generation_server.utils.weights import Weights
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
if isinstance(layer, Fp8Linear):
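# Running an identity matrix through the FP8 layer yields its dequantized weight (transposed), assuming the layer has no bias.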
eye = torch.eye(
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
)
dequant_weights = layer(eye)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
class DeepseekV3Config(PretrainedConfig):
def __init__(
self,
@@ -249,6 +262,44 @@ class DeepseekV3Attention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
def forward(
self,
hidden_states: torch.Tensor,
@@ -261,14 +312,9 @@ class DeepseekV3Attention(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
hidden_states_or_q_c = hidden_states
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
@@ -276,13 +322,18 @@ )
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
# Prefill
if cu_seqlen_prefill is not None:
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
@@ -297,7 +348,30 @@ class DeepseekV3Attention(torch.nn.Module):
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
kv_cache.store(
key=latent_vec_k,
value=None,
slots=slots,
kv_scales=self.kv_scales,
)
if cu_seqlen_prefill is not None:
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
@@ -315,15 +389,6 @@ class DeepseekV3Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
@@ -334,9 +399,15 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
# Decode
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else:
attn_output = paged_attention(
# Decode
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query,
kv_cache,
self.kv_head_mapping,
@@ -344,14 +415,10 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
attn_output = self._v_up_proj_and_o_proj(attn_output)
return attn_output
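The decode path above relies on a weight-absorption identity: scoring a query against an up-projected key gives the same result as scoring a down-projected query against the cached latent. A minimal self-contained check with illustrative shapes (the model code performs the same contractions with torch.bmm over W_UK_T and W_UV):
import torch
batch, num_heads, qk_nope, kv_lora_rank = 2, 4, 8, 16    # toy sizes
W_UK = torch.randn(kv_lora_rank, num_heads, qk_nope)     # latent -> per-head no-PE key
q_nope = torch.randn(batch, num_heads, qk_nope)
latent = torch.randn(batch, kv_lora_rank)                # what KVCompressCache holds
# Expanded path (prefill-style): up-project the latent, then dot with the query.
k_full = torch.einsum("bl,lnd->bnd", latent, W_UK)
scores_full = torch.einsum("bnd,bnd->bn", q_nope, k_full)
# Absorbed path (decode-style): fold W_UK into the query, then dot with the latent.
q_latent = torch.einsum("bnd,lnd->bnl", q_nope, W_UK)
scores_absorbed = torch.einsum("bnl,bl->bn", q_latent, latent)
print(torch.allclose(scores_full, scores_absorbed, atol=1e-5))  # True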
class DeepseekV3MLP(nn.Module):

View File

@@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
)
from text_generation_server.layers.attention import (
KVCache,
KVCompressCache,
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
@@ -68,7 +69,9 @@ from text_generation_server.utils.import_utils import (
synchronize,
get_free_memory,
)
from text_generation_server.utils.prefill_chunking import (
get_max_prefill_tokens,
)
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
@@ -1482,6 +1485,17 @@ class FlashCausalLM(Model):
):
self.kv_cache = []
empty_cache()
if self.config.model_type == "deepseek_v3":
self.kv_cache = [
KVCompressCache(
num_blocks=num_blocks,
head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
else:
self.kv_cache = [
KVCache(
num_blocks=num_blocks,
@@ -1511,8 +1525,14 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
if self.config.model_type == "deepseek_v3":
cache_block_size = BLOCK_SIZE * (
self.config.kv_lora_rank + self.config.qk_rope_head_dim
)
else:
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
try:
self.init_kv_cache(
@@ -1572,7 +1592,7 @@ class FlashCausalLM(Model):
self.kv_cache_dtype,
self.device,
)
self.max_batch_prefill_tokens = max_input_tokens * len(batch)
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
HPUBucketingContext = get_bucketing_context()
max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
@@ -1589,7 +1609,7 @@ class FlashCausalLM(Model):
max_blocks = max(
BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
)
self.bucketing_ctx.num_hpu_blocks = max_blocks
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
@@ -1616,6 +1636,8 @@ class FlashCausalLM(Model):
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
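The special-cased cache_block_size shows how much smaller the compressed cache is. A back-of-the-envelope comparison, assuming DeepSeek-V3-like values (kv_lora_rank=512, qk_rope_head_dim=64, 61 layers) against a hypothetical dense cache with 128 KV heads of head size 128; all numbers below are illustrative only:
BLOCK_SIZE = 128                                  # tokens per block (assumed)
num_layers, dtype_size = 61, 2                    # layers, bytes per bfloat16 element
mla_block = BLOCK_SIZE * (512 + 64)               # one combined latent, no separate K and V
dense_block = BLOCK_SIZE * 128 * 128 * 2          # K and V for every head
print(num_layers * mla_block * dtype_size / 2**20)    # ~8.6 MiB per block
print(num_layers * dense_block * dtype_size / 2**20)  # ~488 MiB per block, ~57x larger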

View File

@@ -350,6 +350,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)

View File

@@ -8,6 +8,7 @@ use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::usage_stats::Env;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
@@ -15,7 +16,6 @@ use text_generation_router::validation::{
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
@@ -185,6 +185,9 @@ struct State {
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
/// Indicates whether we are running on an HPU device; the HPU needs padding to generate the first token.
is_hpu_device: bool,
}
impl State {
@@ -214,6 +217,7 @@ impl State {
speculate,
support_chunking,
block_allocator,
is_hpu_device: Env::new().is_hpu_device(),
}
}
@@ -368,6 +372,21 @@ }
}
}
// HPU padding for the prefill
if self.is_hpu_device {
max_input_length = max_input_length.max(entry.request.input_length);
let actual_prefill_tokens_for_hpu =
(batch.len() + 1) as u32 * max_input_length;
if actual_prefill_tokens_for_hpu > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len;
Some(block_allocation)
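The (batch.len() + 1) * max_input_length check accounts for the fact that an HPU prefill batch is padded so every sequence occupies max_input_length token slots. For example, with three entries already scheduled and the longest prompt at 512 tokens, adding a fourth entry of 200 tokens costs (3 + 1) * 512 = 2048 padded tokens even though the raw prompts sum to far less, and it is this padded figure that must stay under prefill_token_budget.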