Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00
[gaudi] Perf optimization (#3256)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 79183d1647
commit 839477670a
@@ -11,6 +11,7 @@ from .hpu import (
    attention,
    paged_attention,
    paged_attention_mla,
    set_block_mapping,
)
@@ -22,6 +23,7 @@ __all__ = [
    "get_kv_scales",
    "paged_attention",
    "paged_attention_mla",
    "set_block_mapping",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "KVCompressCache",
@@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math

SUPPORTS_WINDOWING = False

@@ -106,6 +107,21 @@ def attention(
    return attn_output


def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
    block_mapping = torch.nn.functional.one_hot(
        hpu_attention_meta.block_groups, num_classes=batch_size
    )
    dtype = hpu_attention_meta.block_usage.dtype
    device = hpu_attention_meta.block_usage.device
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    hpu_attention_meta = hpu_attention_meta._replace(
        attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
    )
    return hpu_attention_meta


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
@@ -176,4 +192,10 @@ def paged_attention_mla(
    return output.view(batch_size, head_num, -1)


__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "paged_attention_mla",
    "set_block_mapping",
]
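For reference, the new set_block_mapping helper just turns the per-block bookkeeping into a one-hot block-to-sequence mapping plus an additive attention bias. Below is a minimal CPU sketch of the same math, assuming a stand-in namedtuple in place of the real HPUPagedAttentionMetadata and an illustrative BLOCK_SIZE:

# Hedged sketch: reproduces the math of set_block_mapping above with plain
# PyTorch on CPU. Meta and BLOCK_SIZE are stand-ins, not the TGI definitions.
import math
from collections import namedtuple

import torch

BLOCK_SIZE = 4  # illustration only; the real value comes from models.globals

Meta = namedtuple("Meta", ["block_groups", "block_usage", "block_mapping", "attn_bias"])

# Two sequences: blocks 0-1 belong to sequence 0, block 2 to sequence 1.
meta = Meta(
    block_groups=torch.tensor([0, 0, 1]),
    block_usage=torch.tensor([4, 2, 3], dtype=torch.float32),
    block_mapping=None,
    attn_bias=None,
)
batch_size = 2

block_mapping = torch.nn.functional.one_hot(meta.block_groups, num_classes=batch_size)
dtype = meta.block_usage.dtype
mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = mask >= meta.block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)

print(block_mapping.to(dtype))  # (num_blocks, batch_size) one-hot assignment
print(attn_bias)                # -inf wherever a block slot is past the used length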
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
        seqlen: torch.Tensor,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
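Every model file touched below follows the same pattern: complete the paged-attention metadata once at the top of forward, keyed off the (padded) batch size. A hedged, runnable toy of that control flow, using placeholder names rather than TGI classes:

# Hedged sketch of the per-model pattern; ToyModel and complete_meta are
# stand-ins. In the real code the helper is set_block_mapping from above.
import torch


def complete_meta(meta, batch_size):
    # stand-in for set_block_mapping: here it only records the batch size
    return dict(meta, batch_size=batch_size)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(32, 8)

    def forward(self, input_ids, hpu_attention_meta=None):
        if hpu_attention_meta is not None:
            hpu_attention_meta = complete_meta(hpu_attention_meta, input_ids.shape[0])
        hidden_states = self.embed_tokens(input_ids)
        return hidden_states, hpu_attention_meta


h, meta = ToyModel()(torch.tensor([1, 2, 3]), {"block_groups": [0, 0, 1]})
print(meta["batch_size"])  # 3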
@@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
    set_block_mapping,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
@@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention_mla,
    set_block_mapping,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
@@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds

        residual = None
@@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.wte(input_ids)

        # Get rotary cos and sin for this forward
@@ -43,6 +43,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.attention import (
    KVCache,
    paged_attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -548,6 +549,10 @@ class Llama4TextModel(nn.Module):
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        hidden_states = inputs_embeds
        bs = seqlen.input_lengths.shape[0]
@@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -549,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        cross_attention_states=None,
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module):
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        adapter_data: Optional[torch.Tensor] = None,
    ):
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
@@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
    set_block_mapping,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
@@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
@@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        hidden_states = self.model(
            input_ids,
            position_ids,
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_in(input_ids)

        # Get rotary cos and sin for this forward
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds

        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = self.model(
@@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -266,7 +267,10 @@ class Qwen3Model(nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:

        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
@@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module):
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        inputs_embeds = self.embed_tokens(input_ids)
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        hidden_states = self.model(
@@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.word_embeddings(input_ids)

        # Get rotary cos and sin for this forward
@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module):
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        if self.process_group.size() > 1:
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    set_block_mapping,
    Seqlen,
    HPUPagedAttentionMetadata,
)
@@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
        adapter_data,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
@@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        hidden_states = self.model(
            input_ids,
            position_ids,
@@ -153,19 +153,14 @@ def prepare_for_decode(
    block_list_device = _async_h2d_tensor_copy(block_list)
    block_groups_device = _async_h2d_tensor_copy(block_groups)
    block_usage_device = _async_h2d_tensor_copy(block_usage)
    block_mapping = torch.nn.functional.one_hot(
        block_groups_device, num_classes=batch_size
    )
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= block_usage_device.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)

    return trim_attn_metadata(
        HPUPagedAttentionMetadata(
            block_list=block_list_device,
            block_groups=block_groups_device,
            block_usage=block_usage_device,
            block_mapping=block_mapping.to(dtype),
            attn_bias=attn_bias,
            block_mapping=None,
            attn_bias=None,
        )
    )
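The hunk above is the other half of that change: decode preparation now only ships the raw block bookkeeping to the device and leaves block_mapping/attn_bias as None, to be derived per forward pass by set_block_mapping. A small sketch of the new flow under those assumptions (stand-in namedtuple, plain CPU tensors instead of async host-to-device copies):

# Hedged sketch of the new decode-prep flow; Meta is a stand-in for the real
# HPUPagedAttentionMetadata and the tensors are illustrative.
from collections import namedtuple

import torch

Meta = namedtuple(
    "Meta", ["block_list", "block_groups", "block_usage", "block_mapping", "attn_bias"]
)


def prepare_for_decode_sketch(block_list, block_groups, block_usage):
    # In the real code these are async host-to-device copies; plain tensors here.
    return Meta(
        block_list=torch.as_tensor(block_list),
        block_groups=torch.as_tensor(block_groups),
        block_usage=torch.as_tensor(block_usage, dtype=torch.float32),
        block_mapping=None,  # derived on device per forward pass
        attn_bias=None,      # derived on device per forward pass
    )


meta = prepare_for_decode_sketch([0, 1, 2], [0, 0, 1], [4, 2, 3])
assert meta.block_mapping is None and meta.attn_bias is None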
@@ -428,10 +423,8 @@ class FlashCausalLMBatch(Batch):
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )
        # put on cpu temporarily, move to hpu in prepare_for_prefill
        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)

        top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@@ -701,7 +694,9 @@ class FlashCausalLMBatch(Batch):

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
    def concatenate(
        cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
    ) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}
@@ -750,7 +745,10 @@ class FlashCausalLMBatch(Batch):
            adapter_meta = None
            adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            if padded_total_bs == batches[0].input_ids.shape[0]:
                input_ids = batches[0].input_ids
            else:
                input_ids = batches[0].input_ids.new_empty(total_batch_size)
            if (
                batches[0].position_ids is not None
                and batches[0].position_ids.dim() == 2
@@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch):
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )
@@ -829,9 +825,12 @@ class FlashCausalLMBatch(Batch):

            index = torch.tensor(list(range(start_index, end_index)), device="cpu")
            top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
            if i > 0:
                all_input_ids_tensor.index_copy_(
                    0,
                    index.to(batch.all_input_ids_tensor.device),
                    batch.all_input_ids_tensor[:valid_bsize, :],
                )

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
@@ -851,9 +850,10 @@ class FlashCausalLMBatch(Batch):
            )

            if not prefilling:
                input_ids.index_copy_(
                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
                )
                if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
                    input_ids.index_copy_(
                        0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
                    )
                position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                slot_indices.index_copy_(
                    0, index, batch.slot_indices + cumulative_slots
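The concatenate changes above avoid reallocating and re-copying the first batch when the padded decode batch size already matches its buffers. A hedged sketch of that reuse decision with illustrative dict-based batches (not the real FlashCausalLMBatch):

# Hedged sketch of the buffer-reuse idea in concatenate(); names and the dict
# batch shape are illustrative only.
import torch


def merge_input_ids(batches, total_batch_size, padded_total_bs=0):
    first = batches[0]["input_ids"]
    if padded_total_bs == first.shape[0]:
        input_ids = first  # reuse: batch 0 already lives at the right offsets
    else:
        input_ids = first.new_empty(total_batch_size)
    start = 0
    for i, b in enumerate(batches):
        n = b["valid_bsize"]
        if padded_total_bs != first.shape[0] or i > 0:
            input_ids[start : start + n] = b["input_ids"][:n]
        start += n
    return input_ids


a = {"input_ids": torch.tensor([1, 2, 0, 0]), "valid_bsize": 2}  # already padded to 4
b = {"input_ids": torch.tensor([3]), "valid_bsize": 1}
print(merge_input_ids([a, b], total_batch_size=3, padded_total_bs=4))  # tensor([1, 2, 3, 0])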
@@ -987,7 +987,6 @@ class FlashCausalLMBatch(Batch):
        else:
            padded_bs = self.input_ids.shape[0]
        slots = self.slots[self.slot_indices]
        extra_pad = padded_bs - self.input_ids.shape[0]

        self.hpu_attn_meta = prepare_for_decode(
            dtype,
@@ -998,17 +997,20 @@ class FlashCausalLMBatch(Batch):
            padded_bs,
            bucketing_ctx,
        )
        self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
        self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
        self.input_ids = F.pad(
            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
        )
        self.position_ids = F.pad(
            self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
        )
        self.input_lengths_tensor = F.pad(
            self.input_lengths_tensor, (0, extra_pad), value=0
            self.input_lengths_tensor,
            (0, padded_bs - self.input_lengths_tensor.shape[0]),
            value=0,
        )
        self.cache_lengths_tensor = F.pad(
            self.cache_lengths_tensor, (0, extra_pad), value=0
        )
        self.all_input_ids_tensor = F.pad(
            self.all_input_ids_tensor,
            (0, 0, 0, extra_pad),
            self.cache_lengths_tensor,
            (0, padded_bs - self.cache_lengths_tensor.shape[0]),
            value=0,
        )
        next_token_chooser_parameters = []
@@ -1028,7 +1030,9 @@ class FlashCausalLMBatch(Batch):
            fsm_grammar_states,
        )

    def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
    def prepare_for_prefill(
        self, max_padded_input_len, max_padded_bs, max_total_tokens
    ):
        # Prepare values if we need to continue prefilling
        # Speculation must be ignored while we prefill even with chunking
        # it simplifies everything
@@ -1044,7 +1048,7 @@ class FlashCausalLMBatch(Batch):
        # need extra pad to match warmup seq
        extra_pad = max_padded_input_len - self.max_input_length
        extra_pad_bs = max_padded_bs - len(self)
        device = self.all_input_ids_tensor.device
        device = "hpu"
        if isinstance(self.input_ids, list) and len(self) > 1:
            input_ids_padded_length = []
            input_ids = []
@@ -1288,12 +1292,17 @@ class FlashCausalLMBatch(Batch):
            self.prefill_next_token_indices = (
                self.prefill_next_token_indices + input_ids_padded_length_tensor
            )

        self.all_input_ids_tensor = F.pad(
            self.all_input_ids_tensor,
            (0, 0, 0, extra_pad_bs),
            value=0,
        all_input_ids_tensor = torch.zeros(
            (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
            dtype=torch.int64,
            device="hpu",
        )
        for i in range(len(self)):
            all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
                self.all_input_ids_tensor[i]
            )
        self.all_input_ids_tensor = all_input_ids_tensor

        next_token_chooser_parameters = []
        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
        pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
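Here prefill preparation allocates all_input_ids_tensor once at its final (max_padded_bs, max_total_tokens) shape on HPU instead of growing it with F.pad later. A CPU sketch of the same preallocation, with made-up sizes:

# Hedged sketch of the prefill-time preallocation; CPU stand-in, illustrative sizes.
import torch


def preallocate_all_input_ids(all_input_ids_tensor, max_padded_bs, max_total_tokens):
    width = max(max_total_tokens, all_input_ids_tensor.shape[-1])
    out = torch.zeros((max_padded_bs, width), dtype=torch.int64)
    for i in range(all_input_ids_tensor.shape[0]):
        out[i, : all_input_ids_tensor.shape[-1]] = all_input_ids_tensor[i]
    return out


current = torch.tensor([[5, 6, 7], [8, 9, 0]])  # 2 requests, 3 tokens each so far
print(preallocate_all_input_ids(current, max_padded_bs=4, max_total_tokens=8).shape)
# torch.Size([4, 8]); no later growth needed during decode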
@@ -1459,6 +1468,8 @@ class FlashCausalLM(Model):
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
        self.bucketing_ctx = None
        self.max_total_tokens = None
        self.max_input_tokens = None
        htorch.core.hpu_set_env()
        if htorch.utils.internal.is_lazy():
            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -1564,6 +1575,14 @@ class FlashCausalLM(Model):
            logger.info,
            f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
        )
        if max_total_tokens is None:
            max_total_tokens = sum(batch.input_lengths)

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1

        self.max_total_tokens = max_total_tokens
        self.max_input_tokens = max_input_tokens
        try:
            self.init_kv_cache(
                batch.num_blocks,
@@ -1597,11 +1616,6 @@ class FlashCausalLM(Model):
        )

        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
        if max_total_tokens is None:
            max_total_tokens = sum(batch.input_lengths)

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1

        self.kv_cache = []
        empty_cache()
@@ -2017,7 +2031,9 @@ class FlashCausalLM(Model):
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_current_length],
            batch.all_input_ids_tensor[
                : batch.next_token_logits.shape[0], : batch.max_current_length
            ],
            batch.next_token_logits,
            speculate,
            batch.speculative_ids,
@@ -2031,14 +2047,29 @@ class FlashCausalLM(Model):
            accepted_ids,
        )
        if batch.valid_indices is not None:
            next_token_logprobs = next_token_logprobs.cpu()
            accepted_ids = accepted_ids.cpu()
            batch.all_input_ids_tensor = batch.all_input_ids_tensor[
                batch.valid_indices
            ]
            next_input_ids = next_input_ids[batch.valid_indices]
            next_token_logprobs = next_token_logprobs[batch.valid_indices]
            accepted_ids = accepted_ids[batch.valid_indices]
            # TODO speculative decoding handling missing
            index = torch.arange(
                0,
                len(batch.valid_indices),
                device=batch.all_input_ids_tensor.device,
            )
            batch.all_input_ids_tensor.index_copy_(
                0, index, batch.all_input_ids_tensor[batch.valid_indices]
            )
            padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
                len(batch.valid_indices)
            )
            next_input_ids.index_copy_(
                0, index, next_input_ids[batch.valid_indices]
            )
            next_input_ids = next_input_ids[:padded_total_bs]

            next_token_logprobs.index_copy_(
                0, index, next_token_logprobs[batch.valid_indices]
            )
            accepted_ids.index_copy_(
                0, index, accepted_ids[batch.valid_indices]
            )
            if speculative_ids is not None:
                speculative_ids = speculative_ids[batch.valid_indices]
            batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
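When some requests finish (valid_indices), the survivors are now compacted to the front of the existing tensors with index_copy_ and the working view is truncated to the bucketed batch size, rather than materializing smaller tensors through fancy indexing; keeping shapes static is generally friendlier to HPU graph replay. A tiny illustrative sketch:

# Hedged sketch of the compaction step; tensors and sizes are made up.
import torch

next_input_ids = torch.tensor([10, 11, 12, 13])
valid_indices = torch.tensor([1, 3])  # requests 1 and 3 survive
padded_total_bs = 2                   # bucketed size for the survivors

index = torch.arange(0, len(valid_indices))
next_input_ids.index_copy_(0, index, next_input_ids[valid_indices])
compact = next_input_ids[:padded_total_bs]
print(compact)  # tensor([11, 13]); the original buffer keeps its length of 4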
@@ -2106,10 +2137,13 @@ class FlashCausalLM(Model):
            batch.slot_indices += accepted_ids[: len(batch)]
        else:
            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
            index = F.pad(
                index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
            )
            index = index.to(batch.all_input_ids_tensor.device)
            batch_idx = torch.arange(
                0,
                batch.all_input_ids_tensor.shape[0],
                index.shape[0],
                dtype=torch.long,
                device=batch.all_input_ids_tensor.device,
            )
@@ -2197,7 +2231,18 @@ class FlashCausalLM(Model):
            htorch.core.mark_step()
        # Stage 2. Prepare new batch for speculative scheduling
        if len(batches) > 1:
            batch = self.batch_type.concatenate(batches)
            if self.bucketing_ctx is not None:
                total_batch_size = 0
                for b in batches:
                    total_batch_size += len(b)
                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
                    total_batch_size
                )
                batch = self.batch_type.concatenate(
                    batches, padded_total_bs=padded_total_bs
                )
            else:
                batch = self.batch_type.concatenate(batches)
        else:
            batch = batches[0]
        prefill = batch.prefilling
@@ -2208,9 +2253,12 @@ class FlashCausalLM(Model):
                        batch.max_input_length
                    ),
                    self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                    self.max_total_tokens,
                )
            else:
                batch.prepare_for_prefill(batch.max_input_length, len(batch))
                batch.prepare_for_prefill(
                    batch.max_input_length, len(batch), self.max_total_tokens
                )
        else:
            batch.prepare_for_decode(
                self.dtype, self.use_contiguous_pa, self.bucketing_ctx
@@ -262,8 +262,8 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
    def concatenate(cls, batches, padded_total_bs: int = 0):
        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
@@ -48,8 +48,8 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super().concatenate(batches)
    def concatenate(cls, batches, padded_total_bs: int = 0):
        batch = super().concatenate(batches, padded_total_bs)
        batch.pixel_values = None
        batch.pixel_attention_mask = None