set block mapping inside model graph

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2025-06-03 23:49:29 -07:00
Parent: 79ee5135e3
Commit: acc02aeb3e
22 changed files with 134 additions and 19 deletions

View File

@@ -11,6 +11,7 @@ from .hpu import (
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
)
@@ -22,6 +23,7 @@ __all__ = [
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",

View File

@@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math
SUPPORTS_WINDOWING = False
@@ -106,6 +107,21 @@ def attention(
return attn_output
def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups, num_classes=batch_size
)
dtype = hpu_attention_meta.block_usage.dtype
device = hpu_attention_meta.block_usage.device
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
return hpu_attention_meta
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
@@ -176,4 +192,10 @@ def paged_attention_mla(
return output.view(batch_size, head_num, -1)
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
]

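Every model file below repeats the same pattern in its forward pass: when paged-attention metadata is present, the block mapping and attention bias are recomputed inside the model graph before the layers run, instead of being precomputed in prepare_for_decode. A minimal sketch of that call site, assuming the imports used by the model files in this diff (the apply_block_mapping helper name is illustrative, not part of the commit):

    from typing import Optional

    import torch

    from text_generation_server.layers.attention import (
        HPUPagedAttentionMetadata,
        set_block_mapping,
    )


    def apply_block_mapping(
        input_ids: torch.Tensor,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> Optional[HPUPagedAttentionMetadata]:
        # Metadata is only built for decode batches (see prepare_for_decode in
        # the last file of this diff); prefill passes None and skips this step.
        if hpu_attention_meta is not None:
            # The leading dimension of the inputs is used as the batch size,
            # matching the models changed in this commit.
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        return hpu_attention_meta
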
View File

@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
seqlen: torch.Tensor,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward

View File

@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward

View File

@@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
residual = None

View File

@@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -43,6 +43,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.attention import (
KVCache,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -548,6 +549,10 @@ class Llama4TextModel(nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
bs = seqlen.input_lengths.shape[0]

View File

@@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -549,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward

View File

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
adapter_data: Optional[torch.Tensor] = None,
):
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer

View File

@@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
@@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,

View File

@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -266,7 +267,10 @@ class Qwen3Model(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = self.model(

View File

@@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
attention,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.word_embeddings(input_ids)
# Get rotary cos and sin for this forward

View File

@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1:

View File

@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,

View File

@@ -153,19 +153,14 @@ def prepare_for_decode(
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)
block_mapping = torch.nn.functional.one_hot(
block_groups_device, num_classes=batch_size
)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage_device.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
block_mapping=None,
attn_bias=None,
)
)
@@ -1298,7 +1293,9 @@ class FlashCausalLMBatch(Batch):
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
all_input_ids_tensor = torch.zeros(
(max_padded_bs, max_total_tokens), dtype=torch.int64, device="hpu"
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
dtype=torch.int64,
device="hpu",
)
for i in range(len(self)):
all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
@@ -2051,8 +2048,6 @@ class FlashCausalLM(Model):
)
if batch.valid_indices is not None:
# TODO speculative decoding handling missing
next_token_logprobs = next_token_logprobs.cpu()
accepted_ids = accepted_ids.cpu()
index = torch.arange(
0,
len(batch.valid_indices),
@@ -2068,8 +2063,13 @@ class FlashCausalLM(Model):
0, index, next_input_ids[batch.valid_indices]
)
next_input_ids = next_input_ids[:padded_total_bs]
next_token_logprobs = next_token_logprobs[batch.valid_indices]
accepted_ids = accepted_ids[batch.valid_indices]
next_token_logprobs.index_copy_(
0, index, next_token_logprobs[batch.valid_indices]
)
accepted_ids.index_copy_(
0, index, accepted_ids[batch.valid_indices]
)
if speculative_ids is not None:
speculative_ids = speculative_ids[batch.valid_indices]
batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[