mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
applied mla to deepseek v2
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
778b61c0da
commit
433029e56f
@ -28,11 +28,12 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
Fp8Linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
paged_attention_mla,
|
||||
set_block_mapping,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -44,6 +45,18 @@ from text_generation_server.utils.weights import Weights
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
|
||||
if isinstance(layer, Fp8Linear):
|
||||
eye = torch.eye(
|
||||
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
|
||||
)
|
||||
dequant_weights = layer(eye)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
|
||||
class DeepseekV2Config(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
@ -246,6 +259,45 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.value_head_size,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
|
||||
q_nope, q_pe = (
|
||||
q_proj(x)
|
||||
.view(-1, self.num_heads, self.head_size)
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
|
||||
return self.o_proj(x)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -258,14 +310,9 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
):
|
||||
if self.q_lora_rank is None:
|
||||
query = self.q_proj(hidden_states)
|
||||
hidden_states_or_q_c = hidden_states
|
||||
else:
|
||||
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
_, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, key_pe = torch.split(
|
||||
@ -273,13 +320,18 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
)
|
||||
|
||||
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
|
||||
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
|
||||
)
|
||||
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
|
||||
query = q_proj(hidden_states_or_q_c)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
query_nope, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
else:
|
||||
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
|
||||
|
||||
batch_size, heads, head_dim = query_pe.shape
|
||||
query_pe = (
|
||||
@ -294,33 +346,47 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
.reshape(batch_size, heads, head_dim)
|
||||
)
|
||||
self.rotary_emb(query_pe, key_pe, cos, sin)
|
||||
latent_vec_k = torch.concat(
|
||||
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
|
||||
)
|
||||
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
|
||||
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
|
||||
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
key=latent_vec_k,
|
||||
value=None,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
kv = self.kv_b_proj(kv_c_normed).view(
|
||||
-1,
|
||||
self.num_key_value_heads,
|
||||
self.qk_nope_head_dim + self.value_head_size,
|
||||
)
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query=query,
|
||||
@ -331,9 +397,15 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
seqlen=seqlen,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
# Decode
|
||||
query = torch.cat([query_nope, query_pe], dim=-1)
|
||||
attn_output = paged_attention_mla(
|
||||
query,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
@ -341,14 +413,10 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
kv_scales=self.kv_scales,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
)
|
||||
|
||||
# Remove padding.
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
attn_output = self._v_up_proj_and_o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
|
@ -1606,7 +1606,7 @@ class FlashCausalLM(Model):
|
||||
):
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
if self.config.model_type == "deepseek_v3":
|
||||
if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
|
||||
self.kv_cache = [
|
||||
KVCompressCache(
|
||||
num_blocks=num_blocks,
|
||||
@ -1646,7 +1646,7 @@ class FlashCausalLM(Model):
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the free memory
|
||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
||||
if self.config.model_type == "deepseek_v3":
|
||||
if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
|
||||
cache_block_size = BLOCK_SIZE * (
|
||||
self.config.kv_lora_rank + self.config.qk_rope_head_dim
|
||||
)
|
||||
|
@ -184,6 +184,7 @@ Text Generation Inference enables serving optimized models on Gaudi hardware. Th
|
||||
|
||||
**Large Language Models (LLMs)**
|
||||
- [deepseek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)
|
||||
- [deepseek-v2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
|
||||
- [Llama2](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b)
|
||||
- [Llama3](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
||||
- [CodeLlama](https://huggingface.co/codellama/CodeLlama-13b-hf)
|
||||
|
Loading…
Reference in New Issue
Block a user