Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
MLA deepseek (#2)
* MLA optimization
* HPU needs padding in the first token generation

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent f728cf69f2
commit b2bd163d19
@@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
+from text_generation_server.layers.fp8 import Fp8Linear
 
 from text_generation_server.layers.lora import (
     LoraLinear,
@@ -27,6 +28,7 @@ __all__ = [
     "TensorParallelEmbedding",
     "SpeculativeHead",
     "LoraLinear",
+    "Fp8Linear",
     "TensorParallelMultiAdapterLinear",
     "TensorParallelAdapterRowLinear",
     "load_layer_norm",
@@ -10,18 +10,21 @@ from .hpu import (
     SUPPORTS_WINDOWING,
     attention,
     paged_attention,
+    paged_attention_mla,
 )
 
 
 # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache, get_kv_scales
+from .kv_cache import KVCache, get_kv_scales, KVCompressCache
 
 __all__ = [
     "attention",
     "get_kv_scales",
     "paged_attention",
+    "paged_attention_mla",
     "SUPPORTS_WINDOWING",
     "KVCache",
+    "KVCompressCache",
     "Seqlen",
     "HPUPagedAttentionMetadata",
     "trim_seqlen_metadata",
@@ -117,7 +117,7 @@ def paged_attention(
     hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
     batch_size, head_num, head_size = query.shape
-    fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
+    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
@@ -138,8 +138,39 @@ def paged_attention(
     return output.view(batch_size, head_num, head_size)
 
 
-__all__ = [
-    "SUPPORTS_WINDOWING",
-    "attention",
-    "paged_attention",
-]
+def paged_attention_mla(
+    query: torch.Tensor,
+    kv_cache: KVCache,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    seqlen: Seqlen,
+    *,
+    kv_scales: KVScales,
+    softcap: Optional[float] = None,
+    hpu_attention_meta: HPUPagedAttentionMetadata,
+    kv_lora_rank: int = 0,
+):
+    batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
+    output = ops.flat_pa_mla(
+        query=query,
+        key_cache=kv_cache.key,
+        value_cache=None,
+        block_list=hpu_attention_meta.block_list,
+        block_mapping=hpu_attention_meta.block_mapping,
+        block_bias=hpu_attention_meta.attn_bias,
+        block_groups=hpu_attention_meta.block_groups,
+        scale=softmax_scale,
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
+        batch2block_matmul_op=Matmul(),
+        block2batch_matmul_op=Matmul(),
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=None,
+        kv_lora_rank=kv_lora_rank,
+    )
+    # Reshape the output tensor.
+    return output.view(batch_size, head_num, -1)
+
+
+__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
@@ -108,6 +108,69 @@ class KVCache:
         )
 
 
+class KVCompressCache(KVCache):
+    """
+    Key-value cache for attention layers.
+    """
+
+    kv_cache: torch.Tensor
+
+    def __init__(
+        self,
+        *,
+        num_blocks: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        """Construct the key-value cache for a layer."""
+        ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
+
+        self.kv_cache = torch.zeros(
+            (num_blocks, BLOCK_SIZE, 1, head_size),
+            dtype=dtype,
+            device=device,
+        )
+
+    @property
+    def dtype(self):
+        """Get the data type of the cache."""
+        return self.kv_cache.dtype
+
+    @property
+    def key(self):
+        """Get the key cache."""
+
+        return self.kv_cache
+
+    @property
+    def value(self):
+        """Get the value cache."""
+
+        return self.kv_cache
+
+    def store(
+        self,
+        *,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        slots: torch.Tensor,
+        kv_scales: KVScales,
+    ):
+        """Store the key and value at the given slots."""
+        ## TODO FP8 kv cache support
+
+        block_idx = slots // BLOCK_SIZE
+        block_offset = slots % BLOCK_SIZE
+        if self.kv_cache.dtype == torch.float8_e4m3fn:
+            key = torch.ops.hpu.cast_to_fp8_v2(
+                key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
+            )[0]
+        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+
+
 def paged_reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
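KVCompressCache above keeps a single latent vector per token in a (num_blocks, BLOCK_SIZE, 1, head_size) buffer, and store() turns a flat slot id into a (block, offset) pair before writing. Below is a minimal sketch of that indexing with made-up sizes; plain tensor assignment stands in for cache_ops.insert_or_update_cache, and BLOCK_SIZE / head_size are illustrative values, not read from the backend.

import torch

BLOCK_SIZE = 128   # hypothetical paged-cache block size
num_blocks = 4
head_size = 576    # e.g. kv_lora_rank + qk_rope_head_dim for an MLA model

kv_cache = torch.zeros(num_blocks, BLOCK_SIZE, 1, head_size)

# Three tokens landing in arbitrary slots of the paged cache.
slots = torch.tensor([0, 129, 300])
latent = torch.randn(len(slots), 1, head_size)  # one latent vector per token

block_idx = slots // BLOCK_SIZE    # which block each slot lives in   -> [0, 1, 2]
block_offset = slots % BLOCK_SIZE  # position inside that block       -> [0, 1, 44]

# Write each token's latent into its (block, offset) cell.
kv_cache[block_idx, block_offset] = latent
assert torch.equal(kv_cache[1, 1], latent[1])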
@@ -28,11 +28,12 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -42,6 +43,18 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_ms
 from text_generation_server.utils.weights import Weights
 
 
+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV3Config(PretrainedConfig):
     def __init__(
         self,
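get_and_maybe_dequant_weights recovers a dense weight from a possibly FP8-quantized linear layer by pushing the identity matrix through it: since the layer computes x @ W.T, layer(eye) returns W.T and the transpose standardizes it back to (output, input). A quick sanity check of the trick on a plain torch.nn.Linear, used here only as a stand-in for the Fp8Linear case:

import torch

layer = torch.nn.Linear(in_features=8, out_features=4, bias=False)

# Feeding the identity returns W.T (shape in_features x out_features)...
eye = torch.eye(layer.in_features)
dequant = layer(eye)

# ...so transposing standardizes it back to (output, input), i.e. layer.weight.
torch.testing.assert_close(dequant.T, layer.weight)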
@@ -249,6 +262,44 @@ class DeepseekV3Attention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
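These new buffers and helpers implement the "MLA optimization" from the commit message as weight absorption: for decode, the no-RoPE part of the query is folded into W_UK so scores can be taken directly against the cached latent, and the attention output is up-projected through W_UV only afterwards. A small numerical sketch of that equivalence, with arbitrary sizes and random weights (double precision only to keep the comparison tight); the shapes mirror _q_proj_and_k_up_proj and _v_up_proj_and_o_proj but nothing here is read from the real model:

import torch

# Arbitrary sizes, not the real DeepSeek config: N heads, latent dim L,
# no-RoPE head dim P, value head dim V, B cached tokens.
N, L, P, V, B = 4, 32, 16, 16, 8

W_UK = torch.randn(N, L, P, dtype=torch.float64)    # latent -> per-head key (no-RoPE part)
W_UV = torch.randn(N, L, V, dtype=torch.float64)    # latent -> per-head value
c = torch.randn(B, L, dtype=torch.float64)          # cached latents (kv_c_normed)
q_nope = torch.randn(N, 1, P, dtype=torch.float64)  # one decode query per head, no-RoPE part

# Reference path: decompress per-head keys/values from the latent, then attend.
k = torch.einsum("bl,nlp->nbp", c, W_UK)             # (N, B, P)
v = torch.einsum("bl,nlv->nbv", c, W_UV)             # (N, B, V)
scores_ref = q_nope @ k.transpose(1, 2)              # (N, 1, B)
out_ref = scores_ref.softmax(dim=-1) @ v             # (N, 1, V)

# Absorbed path: fold W_UK into the query, attend against the latent directly,
# and up-project the attention output through W_UV afterwards.
ql = q_nope @ W_UK.transpose(1, 2)                   # (N, 1, L)
scores = ql @ c.T                                    # (N, 1, B)
out = (scores.softmax(dim=-1) @ c) @ W_UV            # (N, 1, V)

torch.testing.assert_close(scores, scores_ref)
torch.testing.assert_close(out, out_ref)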
@@ -261,14 +312,9 @@ class DeepseekV3Attention(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-        query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
 
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
@@ -276,13 +322,18 @@ class DeepseekV3Attention(torch.nn.Module):
         )
 
         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
-
-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
+
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
 
         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -297,33 +348,47 @@ class DeepseekV3Attention(torch.nn.Module):
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)
+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
+        )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
 
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
-        )
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
 
         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )
 
-        # Prefill
         if cu_seqlen_prefill is not None:
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
+
             # flash attention
             attn_output = attention(
                 query=query,
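In the prefill branch the latent is decompressed back into per-head keys and values, and q/k (head size qk_nope_head_dim + qk_rope_head_dim) come out wider than v (value_head_size), so all three are padded to one common head size before flash attention; the output is sliced back in the next hunk. A tiny illustration of that padding with made-up sizes (none of these numbers are read from the model config):

import torch
import torch.nn.functional as F

# Made-up sizes for illustration only.
tokens, heads = 3, 2
head_size, value_head_size, head_pad_size = 192, 128, 256

query = torch.randn(tokens, heads, head_size)
value = torch.randn(tokens, heads, value_head_size)

# Right-pad the last dimension so q, k and v share a single head size.
query = F.pad(query, (0, head_pad_size - head_size), value=0)
value = F.pad(value, (0, head_pad_size - value_head_size), value=0)
assert query.shape[-1] == value.shape[-1] == head_pad_size

# After attention, only the first value_head_size channels carry real output.
attn_output = torch.randn(tokens, heads, head_pad_size)
attn_output = attn_output[..., :value_head_size]
assert attn_output.shape[-1] == value_head_size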
@@ -334,9 +399,15 @@ class DeepseekV3Attention(torch.nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
@@ -344,14 +415,10 @@ class DeepseekV3Attention(torch.nn.Module):
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+            return attn_output
 
 
 class DeepseekV3MLP(nn.Module):
@@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
 )
 from text_generation_server.layers.attention import (
     KVCache,
+    KVCompressCache,
     Seqlen,
     HPUPagedAttentionMetadata,
     trim_attn_metadata,
@@ -68,7 +69,9 @@ from text_generation_server.utils.import_utils import (
     synchronize,
     get_free_memory,
 )
+from text_generation_server.utils.prefill_chunking import (
+    get_max_prefill_tokens,
+)
 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
@@ -1482,16 +1485,27 @@ class FlashCausalLM(Model):
     ):
         self.kv_cache = []
         empty_cache()
-        self.kv_cache = [
-            KVCache(
-                num_blocks=num_blocks,
-                num_heads=num_heads,
-                head_size=head_size,
-                dtype=dtype,
-                device=device,
-            )
-            for _ in range(num_layers)
-        ]
+        if self.config.model_type == "deepseek_v3":
+            self.kv_cache = [
+                KVCompressCache(
+                    num_blocks=num_blocks,
+                    head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    dtype=dtype,
+                    device=device,
+                )
+                for _ in range(num_layers)
+            ]
+        else:
+            self.kv_cache = [
+                KVCache(
+                    num_blocks=num_blocks,
+                    num_heads=num_heads,
+                    head_size=head_size,
+                    dtype=dtype,
+                    device=device,
+                )
+                for _ in range(num_layers)
+            ]
 
     def warmup(
         self,
@@ -1511,8 +1525,14 @@ class FlashCausalLM(Model):
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
-        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+        if self.config.model_type == "deepseek_v3":
+            cache_block_size = BLOCK_SIZE * (
+                self.config.kv_lora_rank + self.config.qk_rope_head_dim
+            )
+        else:
+            cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+            cache_block_size = cache_block_size * 2
+        total_cache_size = self.num_layers * cache_block_size * dtype_size
 
         try:
             self.init_kv_cache(
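The new sizing branch reflects the memory saving of the compressed cache: one latent of kv_lora_rank + qk_rope_head_dim values per token, shared across heads and with no separate K and V copies, versus num_kv_heads * head_size values stored twice. A rough comparison with assumed, DeepSeek-V3-like numbers (all values here are illustrative, not read from any real config):

# Assumed, illustrative values only.
BLOCK_SIZE = 128
kv_lora_rank = 512
qk_rope_head_dim = 64
num_kv_heads = 128
head_size = 192      # qk_nope_head_dim + qk_rope_head_dim
dtype_size = 2       # e.g. bfloat16

# Compressed MLA cache: one latent vector per token, shared by all heads.
mla_block = BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim) * dtype_size

# Regular cache: per-head K and V for every token.
kv_block = BLOCK_SIZE * num_kv_heads * head_size * 2 * dtype_size

print(mla_block, kv_block, kv_block / mla_block)  # 147456 12582912 -> roughly 85x smaller per block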
@@ -1572,7 +1592,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
-        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
+        self.max_batch_prefill_tokens = get_max_prefill_tokens()
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
         max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
@@ -1589,7 +1609,7 @@ class FlashCausalLM(Model):
         max_blocks = max(
             BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
         )
-        self.bucketing_ctx.num_hpu_blocks = max_blocks
+        self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             self.bucketing_ctx.generate_prompt_buckets()
             self.bucketing_ctx.generate_decode_buckets(
|
|||||||
for i, (batch_size, seq_len) in enumerate(
|
for i, (batch_size, seq_len) in enumerate(
|
||||||
reversed(self.bucketing_ctx.prompt_buckets)
|
reversed(self.bucketing_ctx.prompt_buckets)
|
||||||
):
|
):
|
||||||
|
if batch_size * seq_len > self.max_batch_prefill_tokens:
|
||||||
|
continue
|
||||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||||
for index in range(warmup_times):
|
for index in range(warmup_times):
|
||||||
self.warmup_prefill(seq_len, batch_size, batch)
|
self.warmup_prefill(seq_len, batch_size, batch)
|
||||||
|
@@ -350,6 +350,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
@@ -8,6 +8,7 @@ use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::usage_stats::Env;
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
@@ -15,7 +16,6 @@ use text_generation_router::validation::{
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
-
 /// Queue entry
 #[derive(Debug)]
 pub(crate) struct Entry {
@@ -185,6 +185,9 @@ struct State {
 
     /// Paged Attention Block Allocation
     block_allocator: Option<BlockAllocator>,
+
+    /// indicate if it's hpu device, the hpu device needs padding to generate first token.
+    is_hpu_device: bool,
 }
 
 impl State {
@@ -214,6 +217,7 @@ impl State {
             speculate,
             support_chunking,
             block_allocator,
+            is_hpu_device: Env::new().is_hpu_device(),
         }
     }
 
@@ -368,6 +372,21 @@ impl State {
                 }
             }
 
+            //HPU padding for the prefill
+            if self.is_hpu_device {
+                max_input_length = max_input_length.max(entry.request.input_length);
+                let actual_prefill_tokens_for_hpu =
+                    (batch.len() + 1) as u32 * max_input_length;
+
+                if actual_prefill_tokens_for_hpu > prefill_token_budget {
+                    // Entry is over budget
+                    // Add it back to the front
+                    tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
+                    self.entries.push_front((id, entry));
+                    break 'entry_loop;
+                }
+            }
+
             prefill_tokens += postfix_len;
 
             Some(block_allocation)
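This queue change covers the second bullet of the commit message: on HPU the prefill batch is padded, so every sequence effectively costs max_input_length tokens, and the batch is cut off once (batch_len + 1) * max_input_length would exceed the prefill budget. The same arithmetic as a small Python sketch with hypothetical numbers (not taken from any real deployment):

# Hypothetical numbers, mirroring the Rust check above.
prefill_token_budget = 4096
input_lengths = [900, 350, 1200, 800]  # candidate entries, in queue order

batch = []
max_input_length = 0
for length in input_lengths:
    padded_max = max(max_input_length, length)
    # On HPU every sequence in the padded batch costs `padded_max` tokens.
    actual_prefill_tokens_for_hpu = (len(batch) + 1) * padded_max
    if actual_prefill_tokens_for_hpu > prefill_token_budget:
        break  # the entry would go back to the front of the queue
    batch.append(length)
    max_input_length = padded_max

print(batch)  # [900, 350, 1200]: 3 * 1200 = 3600 fits; adding 800 would cost 4 * 1200 = 4800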