Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 07:42:06 +00:00
set block mapping inside model graph
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 79ee5135e3
commit acc02aeb3e
@@ -11,6 +11,7 @@ from .hpu import (
     attention,
     paged_attention,
     paged_attention_mla,
+    set_block_mapping,
 )


@@ -22,6 +23,7 @@ __all__ = [
     "get_kv_scales",
     "paged_attention",
     "paged_attention_mla",
+    "set_block_mapping",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "KVCompressCache",
@@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
 from text_generation_server.models.globals import BLOCK_SIZE
+import math

 SUPPORTS_WINDOWING = False

@@ -106,6 +107,21 @@ def attention(
     return attn_output


+def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
+    block_mapping = torch.nn.functional.one_hot(
+        hpu_attention_meta.block_groups, num_classes=batch_size
+    )
+    dtype = hpu_attention_meta.block_usage.dtype
+    device = hpu_attention_meta.block_usage.device
+    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+    mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
+    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+    hpu_attention_meta = hpu_attention_meta._replace(
+        attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
+    )
+    return hpu_attention_meta
+
+
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
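For orientation, a minimal self-contained sketch of what the new set_block_mapping helper computes; BLOCK_SIZE and the tensor values below are made-up example values, not taken from the repository:

import math
import torch

BLOCK_SIZE = 4  # illustrative only; the real value comes from models.globals

# Each KV-cache block is tagged with the batch/sequence index that owns it.
block_groups = torch.tensor([0, 0, 1])  # sequence 0 owns two blocks, sequence 1 owns one
block_usage = torch.tensor([4, 2, 3])   # number of slots actually used in each block

# One-hot over the owning sequence gives the [num_blocks, batch_size] mapping.
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=2)
# tensor([[1, 0],
#         [1, 0],
#         [0, 1]])

# Slots beyond a block's used length get a -inf additive attention bias.
mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0) >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)
# second block (usage 2) -> [0., 0., -inf, -inf]

The model hunks below apply the same pattern everywhere: each forward calls the helper with its batch size (input_ids.shape[0] or inputs_embeds.shape[0]) at the top of the method, so this computation runs inside the model graph.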
@@ -176,4 +192,10 @@ def paged_attention_mla(
     return output.view(batch_size, head_num, -1)


-__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
+__all__ = [
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "paged_attention_mla",
+    "set_block_mapping",
+]
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
         seqlen: torch.Tensor,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention_mla,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
@@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         residual = None
@@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.wte(input_ids)

         # Get rotary cos and sin for this forward
@@ -43,6 +43,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.attention import (
     KVCache,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -548,6 +549,10 @@ class Llama4TextModel(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )

         hidden_states = inputs_embeds
         bs = seqlen.input_lengths.shape[0]
@@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -549,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         cross_attention_states=None,
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
+
         hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         adapter_data: Optional[torch.Tensor] = None,
     ):
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
@@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_in(input_ids)

         # Get rotary cos and sin for this forward
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         inputs_embeds = self.embed_tokens(input_ids)

         hidden_states = self.model(
@@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -266,7 +267,10 @@ class Qwen3Model(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
-
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         # create position embeddings to be shared across the decoder layers
@@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         inputs_embeds = self.embed_tokens(input_ids)
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         hidden_states = self.model(
@@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.word_embeddings(input_ids)

         # Get rotary cos and sin for this forward
@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)

         if self.process_group.size() > 1:
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
         adapter_data,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -153,19 +153,14 @@ def prepare_for_decode(
     block_list_device = _async_h2d_tensor_copy(block_list)
     block_groups_device = _async_h2d_tensor_copy(block_groups)
     block_usage_device = _async_h2d_tensor_copy(block_usage)
-    block_mapping = torch.nn.functional.one_hot(
-        block_groups_device, num_classes=batch_size
-    )
-    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage_device.unsqueeze(-1)
-    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
             block_list=block_list_device,
             block_groups=block_groups_device,
             block_usage=block_usage_device,
-            block_mapping=block_mapping.to(dtype),
-            attn_bias=attn_bias,
+            block_mapping=None,
+            attn_bias=None,
         )
     )

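One detail worth noting while reading the hunks above: the metadata object behaves like a named tuple (the new helper calls _replace on it), so set_block_mapping returns a fresh object rather than mutating the one that prepare_for_decode now builds with block_mapping=None and attn_bias=None; that is why every model reassigns hpu_attention_meta = set_block_mapping(...). A tiny self-contained illustration, where MetaSketch is a made-up stand-in and not the real HPUPagedAttentionMetadata:

from typing import NamedTuple, Optional
import torch

class MetaSketch(NamedTuple):
    # Made-up stand-in, just to show the _replace semantics used by the helper.
    block_usage: torch.Tensor
    block_mapping: Optional[torch.Tensor] = None
    attn_bias: Optional[torch.Tensor] = None

meta = MetaSketch(block_usage=torch.tensor([4, 2, 3]))   # as produced host-side, mapping/bias unset
updated = meta._replace(block_mapping=torch.eye(3))      # as done inside the model graph
print(meta.block_mapping)     # None: the original tuple is untouched
print(updated.block_mapping)  # the newly computed mapping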
@@ -1298,7 +1293,9 @@ class FlashCausalLMBatch(Batch):
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
         all_input_ids_tensor = torch.zeros(
-            (max_padded_bs, max_total_tokens), dtype=torch.int64, device="hpu"
+            (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            dtype=torch.int64,
+            device="hpu",
         )
         for i in range(len(self)):
             all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
@@ -2051,8 +2048,6 @@ class FlashCausalLM(Model):
             )
             if batch.valid_indices is not None:
                 # TODO speculative decoding handling missing
-                next_token_logprobs = next_token_logprobs.cpu()
-                accepted_ids = accepted_ids.cpu()
                 index = torch.arange(
                     0,
                     len(batch.valid_indices),
@@ -2068,8 +2063,13 @@
                     0, index, next_input_ids[batch.valid_indices]
                 )
                 next_input_ids = next_input_ids[:padded_total_bs]
-                next_token_logprobs = next_token_logprobs[batch.valid_indices]
-                accepted_ids = accepted_ids[batch.valid_indices]
+                next_token_logprobs.index_copy_(
+                    0, index, next_token_logprobs[batch.valid_indices]
+                )
+                accepted_ids.index_copy_(
+                    0, index, accepted_ids[batch.valid_indices]
+                )
                 if speculative_ids is not None:
                     speculative_ids = speculative_ids[batch.valid_indices]
                     batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
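A small standalone illustration of the index_copy_ pattern the last hunk switches to (the values below are made up): instead of slicing next_token_logprobs and accepted_ids down to batch.valid_indices, the surviving rows are written back into the front of the existing tensors, which keeps their padded batch dimension intact.

import torch

# Padded per-request values (padded batch size 4); only requests 0 and 2 are still valid.
next_token_logprobs = torch.tensor([-0.1, -0.5, -0.2, -0.9])
valid_indices = torch.tensor([0, 2])
index = torch.arange(0, len(valid_indices))

# In place: compact the valid rows to the front, keep the padded length.
next_token_logprobs.index_copy_(0, index, next_token_logprobs[valid_indices])
print(next_token_logprobs)  # tensor([-0.1000, -0.2000, -0.2000, -0.9000])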