diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py
index 0000ca91..fd146728 100644
--- a/backends/gaudi/server/text_generation_server/layers/__init__.py
+++ b/backends/gaudi/server/text_generation_server/layers/__init__.py
@@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
+from text_generation_server.layers.fp8 import Fp8Linear
 
 from text_generation_server.layers.lora import (
     LoraLinear,
@@ -27,6 +28,7 @@ __all__ = [
     "TensorParallelEmbedding",
     "SpeculativeHead",
     "LoraLinear",
+    "Fp8Linear",
     "TensorParallelMultiAdapterLinear",
     "TensorParallelAdapterRowLinear",
     "load_layer_norm",
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
index 89a43d65..370e05bc 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -10,18 +10,21 @@ from .hpu import (
     SUPPORTS_WINDOWING,
     attention,
     paged_attention,
+    paged_attention_mla,
 )
 
 # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache, get_kv_scales
+from .kv_cache import KVCache, get_kv_scales, KVCompressCache
 
 __all__ = [
     "attention",
     "get_kv_scales",
     "paged_attention",
+    "paged_attention_mla",
     "SUPPORTS_WINDOWING",
     "KVCache",
+    "KVCompressCache",
     "Seqlen",
     "HPUPagedAttentionMetadata",
     "trim_seqlen_metadata",
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index 092fe138..1c2e37c7 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -117,7 +117,7 @@ def paged_attention(
     hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
     batch_size, head_num, head_size = query.shape
-    fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
+    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
@@ -138,8 +138,39 @@ def paged_attention(
     return output.view(batch_size, head_num, head_size)
 
 
-__all__ = [
-    "SUPPORTS_WINDOWING",
-    "attention",
-    "paged_attention",
-]
+def paged_attention_mla(
+    query: torch.Tensor,
+    kv_cache: KVCache,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    seqlen: Seqlen,
+    *,
+    kv_scales: KVScales,
+    softcap: Optional[float] = None,
+    hpu_attention_meta: HPUPagedAttentionMetadata,
+    kv_lora_rank: int = 0,
+):
+    batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
+    output = ops.flat_pa_mla(
+        query=query,
+        key_cache=kv_cache.key,
+        value_cache=None,
+        block_list=hpu_attention_meta.block_list,
+        block_mapping=hpu_attention_meta.block_mapping,
+        block_bias=hpu_attention_meta.attn_bias,
+        block_groups=hpu_attention_meta.block_groups,
+        scale=softmax_scale,
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
+        batch2block_matmul_op=Matmul(),
+        block2batch_matmul_op=Matmul(),
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=None,
+        kv_lora_rank=kv_lora_rank,
+    )
+    # Reshape the output tensor.
+    return output.view(batch_size, head_num, -1)
+
+
+__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
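The `paged_attention_mla` wrapper above hands the whole latent cache to `ops.flat_pa_mla`: each query head carries `kv_lora_rank + qk_rope_head_dim` features, the cached latent vector doubles as the key, its first `kv_lora_rank` features double as the value, and the per-head output therefore has `kv_lora_rank` features. The snippet below is a minimal, non-paged PyTorch sketch of that shape contract, for illustration only; it is not the HPU kernel, and the function and tensor names are invented for the example.

```python
import torch


def mla_decode_reference(query, latent_cache, kv_lora_rank, softmax_scale):
    """Toy, non-paged reference for MLA decode attention.

    query:        [batch, num_heads, kv_lora_rank + qk_rope_head_dim]
    latent_cache: [batch, seq_len, kv_lora_rank + qk_rope_head_dim]
                  (one shared latent + RoPE vector per token, no per-head K/V)
    returns:      [batch, num_heads, kv_lora_rank]
    """
    # Keys are the full latent+RoPE vectors; values are only the latent part.
    keys = latent_cache
    values = latent_cache[..., :kv_lora_rank]

    # [B, H, D] x [B, S, D] -> [B, H, S]
    scores = torch.einsum("bhd,bsd->bhs", query, keys) * softmax_scale
    probs = torch.softmax(scores, dim=-1)

    # [B, H, S] x [B, S, L] -> [B, H, L]
    return torch.einsum("bhs,bsl->bhl", probs, values)


if __name__ == "__main__":
    B, H, S, L, R = 2, 8, 16, 512, 64
    q = torch.randn(B, H, L + R)
    cache = torch.randn(B, S, L + R)
    out = mla_decode_reference(q, cache, kv_lora_rank=L, softmax_scale=(L + R) ** -0.5)
    print(out.shape)  # torch.Size([2, 8, 512])
```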
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
index e6c5f67d..cdd1e1d7 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -108,6 +108,69 @@ class KVCache:
         )
 
 
+class KVCompressCache(KVCache):
+    """
+    Compressed (latent) KV cache for MLA attention layers; keys and values share a single tensor.
+    """
+
+    kv_cache: torch.Tensor
+
+    def __init__(
+        self,
+        *,
+        num_blocks: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        """Construct the key-value cache for a layer."""
+        ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
+
+        self.kv_cache = torch.zeros(
+            (num_blocks, BLOCK_SIZE, 1, head_size),
+            dtype=dtype,
+            device=device,
+        )
+
+    @property
+    def dtype(self):
+        """Get the data type of the cache."""
+        return self.kv_cache.dtype
+
+    @property
+    def key(self):
+        """Get the key cache."""
+
+        return self.kv_cache
+
+    @property
+    def value(self):
+        """Get the value cache."""
+
+        return self.kv_cache
+
+    def store(
+        self,
+        *,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        slots: torch.Tensor,
+        kv_scales: KVScales,
+    ):
+        """Store the latent KV vectors at the given slots; ``value`` is ignored."""
+        ## TODO FP8 kv cache support
+
+        block_idx = slots // BLOCK_SIZE
+        block_offset = slots % BLOCK_SIZE
+        if self.kv_cache.dtype == torch.float8_e4m3fn:
+            key = torch.ops.hpu.cast_to_fp8_v2(
+                key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
+            )[0]
+        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+
+
 def paged_reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
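`KVCompressCache.store` addresses the cache by flat slot id: `slots // BLOCK_SIZE` picks the block and `slots % BLOCK_SIZE` the row inside it, and exactly one latent vector is written per token. The sketch below mimics that addressing with a plain tensor scatter, assuming `cache_ops.insert_or_update_cache` behaves like a row-wise insert; the `BLOCK_SIZE` value and shapes are illustrative, not the Gaudi implementation.

```python
import torch

BLOCK_SIZE = 128  # illustrative; the real value comes from the serving config


def store_latents_reference(kv_cache, latents, slots):
    """Plain-tensor stand-in for cache_ops.insert_or_update_cache.

    kv_cache: [num_blocks, BLOCK_SIZE, 1, head_size]
    latents:  [num_tokens, head_size]  (kv_c_normed concatenated with key_pe)
    slots:    [num_tokens] flat slot ids
    """
    block_idx = slots // BLOCK_SIZE
    block_offset = slots % BLOCK_SIZE
    kv_cache[block_idx, block_offset, 0] = latents
    return kv_cache


if __name__ == "__main__":
    head_size = 512 + 64  # kv_lora_rank + qk_rope_head_dim for DeepSeek-V3
    cache = torch.zeros(4, BLOCK_SIZE, 1, head_size)
    latents = torch.randn(3, head_size)
    slots = torch.tensor([0, 130, 257])
    store_latents_reference(cache, latents, slots)
    assert torch.equal(cache[1, 2, 0], latents[1])  # slot 130 -> block 1, row 2
```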
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
index 1a7ce5cf..f6620d51 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
@@ -28,11 +28,12 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -42,6 +43,18 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_ms
 from text_generation_server.utils.weights import Weights
 
 
+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV3Config(PretrainedConfig):
     def __init__(
         self,
@@ -249,6 +262,44 @@ class DeepseekV3Attention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
 
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -261,14 +312,9 @@ class DeepseekV3Attention(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-            query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
 
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
         )
         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
 
-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
 
         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -297,33 +348,47 @@
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)
+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
+        )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
 
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
-        )
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
 
         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )
 
-        # Prefill
         if cu_seqlen_prefill is not None:
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
+
             # flash attention
             attn_output = attention(
                 query=query,
@@ -334,9 +399,15 @@
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
@@ -344,14 +415,10 @@
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+        return attn_output
 
 
 class DeepseekV3MLP(nn.Module):
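The decode helpers `_q_proj_and_k_up_proj` and `_v_up_proj_and_o_proj` added above rely on the usual MLA weight-absorption identity: scoring `q_nope @ W_UK^T` against the cached latents gives the same attention scores as expanding the latents into per-head keys with `kv_b_proj`, and applying `W_UV` after attending over latents matches attending over expanded values. A small self-contained check of both identities, with made-up sizes rather than the real DeepSeek-V3 configuration:

```python
import torch

torch.manual_seed(0)

# Illustrative sizes, not the real DeepSeek-V3 configuration.
L, P, V, H, S = 16, 8, 8, 2, 5  # kv_lora_rank, qk_nope_head_dim, value_head_size, heads, seq

W_UK = torch.randn(H, L, P)  # per-head latent -> key_nope, as split out of kv_b_proj
W_UV = torch.randn(H, L, V)  # per-head latent -> value

c = torch.randn(S, L)        # cached latent vectors (kv_c_normed)
q_nope = torch.randn(H, P)   # one decode query, per head

# Naive path: expand the cache into per-head keys, then score.
k_nope = torch.einsum("sl,hlp->hsp", c, W_UK)
scores_naive = torch.einsum("hp,hsp->hs", q_nope, k_nope)

# Absorbed path (_q_proj_and_k_up_proj): fold W_UK into the query instead.
ql_nope = torch.einsum("hp,hlp->hl", q_nope, W_UK)
scores_absorbed = torch.einsum("hl,sl->hs", ql_nope, c)
print(torch.allclose(scores_naive, scores_absorbed, atol=1e-4))  # True

# Value side: attending over latents and applying W_UV afterwards
# (_v_up_proj_and_o_proj) matches attending over expanded values.
probs = torch.softmax(scores_absorbed, dim=-1)
v = torch.einsum("sl,hlv->hsv", c, W_UV)
out_naive = torch.einsum("hs,hsv->hv", probs, v)
out_absorbed = torch.einsum("hl,hlv->hv", torch.einsum("hs,sl->hl", probs, c), W_UV)
print(torch.allclose(out_naive, out_absorbed, atol=1e-4))  # True
```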
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 4217c17b..8bbd46b5 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
 )
 from text_generation_server.layers.attention import (
     KVCache,
+    KVCompressCache,
     Seqlen,
     HPUPagedAttentionMetadata,
     trim_attn_metadata,
@@ -68,7 +69,9 @@ from text_generation_server.utils.import_utils import (
     synchronize,
     get_free_memory,
 )
-
+from text_generation_server.utils.prefill_chunking import (
+    get_max_prefill_tokens,
+)
 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
@@ -1482,16 +1485,27 @@
     ):
         self.kv_cache = []
         empty_cache()
-        self.kv_cache = [
-            KVCache(
-                num_blocks=num_blocks,
-                num_heads=num_heads,
-                head_size=head_size,
-                dtype=dtype,
-                device=device,
-            )
-            for _ in range(num_layers)
-        ]
+        if self.config.model_type == "deepseek_v3":
+            self.kv_cache = [
+                KVCompressCache(
+                    num_blocks=num_blocks,
+                    head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    dtype=dtype,
+                    device=device,
+                )
+                for _ in range(num_layers)
+            ]
+        else:
+            self.kv_cache = [
+                KVCache(
+                    num_blocks=num_blocks,
+                    num_heads=num_heads,
+                    head_size=head_size,
+                    dtype=dtype,
+                    device=device,
+                )
+                for _ in range(num_layers)
+            ]
 
     def warmup(
         self,
@@ -1511,8 +1525,14 @@
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
-        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+        if self.config.model_type == "deepseek_v3":
+            cache_block_size = BLOCK_SIZE * (
+                self.config.kv_lora_rank + self.config.qk_rope_head_dim
+            )
+        else:
+            cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+        cache_block_size = cache_block_size * 2
+        total_cache_size = self.num_layers * cache_block_size * dtype_size
 
         try:
             self.init_kv_cache(
@@ -1572,7 +1592,7 @@
             self.kv_cache_dtype,
             self.device,
         )
-        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
+        self.max_batch_prefill_tokens = get_max_prefill_tokens()
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
         max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
@@ -1589,7 +1609,7 @@
         max_blocks = max(
             BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
         )
-        self.bucketing_ctx.num_hpu_blocks = max_blocks
+        self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             self.bucketing_ctx.generate_prompt_buckets()
             self.bucketing_ctx.generate_decode_buckets(
@@ -1616,6 +1636,8 @@
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
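The warmup sizing above budgets one latent vector of `kv_lora_rank + qk_rope_head_dim` features per cached token for `deepseek_v3`, instead of `num_kv_heads * head_size` entries, and keeps the `* 2` factor in both branches. A back-of-the-envelope mirror of that arithmetic, with illustrative numbers only:

```python
import torch


def cache_block_bytes(block_size, dtype, *, num_kv_heads=None, head_size=None,
                      kv_lora_rank=None, qk_rope_head_dim=None, mla=False):
    """Mirror of the per-layer cache-block sizing used in warmup()."""
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    if mla:
        per_token = kv_lora_rank + qk_rope_head_dim  # one shared latent vector
    else:
        per_token = num_kv_heads * head_size         # per-head K/V entries
    # The diff keeps the factor of 2 in both branches.
    return block_size * per_token * 2 * dtype_size


if __name__ == "__main__":
    BLOCK_SIZE = 128  # illustrative
    regular = cache_block_bytes(BLOCK_SIZE, torch.bfloat16, num_kv_heads=8, head_size=128)
    latent = cache_block_bytes(BLOCK_SIZE, torch.bfloat16,
                               kv_lora_rank=512, qk_rope_head_dim=64, mla=True)
    print(regular, latent)  # 524288 294912
```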
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index f9186450..dac65fea 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -350,6 +350,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index d3bf4b9c..5bd4d03c 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -8,6 +8,7 @@ use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::usage_stats::Env;
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
@@ -15,7 +16,6 @@ use text_generation_router::validation::{
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
-
 /// Queue entry
 #[derive(Debug)]
 pub(crate) struct Entry {
@@ -185,6 +185,9 @@ struct State {
 
     /// Paged Attention Block Allocation
     block_allocator: Option<BlockAllocator>,
+
+    /// Indicates whether this is an HPU device; HPU pads the prefill batch, so it needs extra budget to generate the first token.
+    is_hpu_device: bool,
 }
 
 impl State {
@@ -214,6 +217,7 @@ impl State {
             speculate,
             support_chunking,
             block_allocator,
+            is_hpu_device: Env::new().is_hpu_device(),
         }
     }
 
@@ -368,6 +372,21 @@ impl State {
                 }
             }
 
+            // HPU padding for the prefill
+            if self.is_hpu_device {
+                max_input_length = max_input_length.max(entry.request.input_length);
+                let actual_prefill_tokens_for_hpu =
+                    (batch.len() + 1) as u32 * max_input_length;
+
+                if actual_prefill_tokens_for_hpu > prefill_token_budget {
+                    // Entry is over budget
+                    // Add it back to the front
+                    tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
+                    self.entries.push_front((id, entry));
+                    break 'entry_loop;
+                }
+            }
+
             prefill_tokens += postfix_len;
 
             Some(block_allocation)
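The HPU branch above prices a candidate prefill as `(batch.len() + 1) * max_input_length`, since every sequence in the batch is padded to the longest input before the first token can be generated. A tiny Python mirror of that budgeting rule, with hypothetical numbers:

```python
def hpu_padded_prefill_tokens(batch_len: int, max_input_length: int) -> int:
    """Prefill cost model from the HPU branch: every sequence in the candidate
    batch is padded to the longest input length."""
    return (batch_len + 1) * max_input_length


if __name__ == "__main__":
    budget = 512
    # Three requests already selected, longest input so far 100 tokens;
    # a new request of 120 tokens arrives.
    needed = hpu_padded_prefill_tokens(batch_len=3, max_input_length=max(100, 120))
    print(needed, needed <= budget)  # 480 True -> the entry still fits the budget
```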