commit 87840ab374
Author: Wang, Yi
Date: 2025-04-18 19:54:59 +05:30
Committed by: GitHub (GPG Key ID: B5690EEEBB952194; no known key found for this signature in database)
7 changed files with 921 additions and 559 deletions

View File

@@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
@@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
"block_list",
"block_mapping",
"block_usage",
"block_scales",
"block_groups",
"attn_bias",
],
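For reference, a minimal sketch of the trimmed-metadata path after this change, with `block_scales` dropped from both the dataclass and the trim list. The namedtuple-based helper below is an illustration of the idea, not the repository's exact implementation:

```python
# Sketch only: the field names mirror the diff above; the trimming helper
# itself is an assumption, not the repo's exact code.
import collections
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class HPUPagedAttentionMetadata:
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]


def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
    # Forward only the tensors the HPU attention kernels still consume;
    # block_scales is no longer among them.
    fields = ["block_list", "block_mapping", "block_usage", "block_groups", "attn_bias"]
    Trimmed = collections.namedtuple("TrimmedHPUAttentionMetadata", fields)
    return Trimmed(**{name: getattr(metadata, name) for name in fields})
```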

View File

@@ -74,7 +74,6 @@ def paged_attention(
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_scales=hpu_attention_meta.block_scales,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=Matmul(),

View File

@@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module):
# bsz, q_len, _ = hidden_states.size()
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
cross_attention_len,
indices,
) = cross_attention_states
bs = cu_seqlen_q.size(0) - 1
bs = cross_attention_len.size(0)
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
@@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
indices = cross_attention_states[-1]
out_hidden_states = hidden_states[:]
if len(indices) > 0:
assert max(indices) < hidden_states.shape[0]
hidden_states = hidden_states[indices]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
indices=None,
cross_attention_len: Optional[torch.Tensor] = None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index].item()
assert index < seqlen.cu_seqlen_q.shape[0]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
assert max(indices) < input_ids.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
indices = image_indices[:]
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
cross_attention_len,
indices,
)
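A toy illustration of the new cross-attention metadata contract (values and shapes are invented; only the 3-tuple layout, the `bs = cross_attention_len.size(0)` derivation, and the meaning of `cross_attention_len` come from the diff):

```python
import torch

# Two images, each contributing 4 vision tokens (hidden size 128 is arbitrary).
cross_attention_states = torch.zeros(2, 4, 128)
# One entry per image: the input length of the request carrying that image.
cross_attention_len = torch.tensor([4, 4], dtype=torch.int32)
# Flat positions (into the packed input_ids) of the text tokens that attend
# to the vision states.
indices = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11])

# The layers now receive this 3-tuple instead of
# (states, cu_seqlen_q, cu_seqlen_k, indices); the cu_seqlen tensors are no
# longer rebuilt inside FlashMllamaForConditionalGeneration.forward.
packed = (cross_attention_states, cross_attention_len, indices)

# Inside MllamaTextCrossAttention.forward the batch size is read directly
# from the per-image lengths:
bs = cross_attention_len.size(0)  # == 2
```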

View File

@@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
prepare_for_decode,
)
from text_generation_server.models.globals import PREFIX_CACHING
from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
tracer = trace.get_tracer(__name__)
@@ -375,6 +380,91 @@ class FlashVlmCausalLM(FlashCausalLM):
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
):
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
if batch.position_ids is not None and batch.position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = position_ids.unsqueeze(-1).repeat(
(1, batch.position_ids.shape[-1])
)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
# fetch the last slot of each sequence's final block so the warmup covers block_num blocks
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
# Decode warmup: cu_seqlen_prefill is None below, so this exercises the paged-attention decode path with the prepared hpu_attention_meta.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots_tensor,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
pixel_values=None,
pixel_attention_mask=None,
image_sizes=None,
image_grid_thw=None,
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
warmup_times = 3
# only warm up decode; for prefill, the image pixel size may change, which would make the warmup useless
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
def forward(
self,
batch: FlashVlmCausalLMBatch,
@@ -450,17 +540,75 @@ class FlashVlmCausalLM(FlashCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
kwargs["bypass_hpu_graphs"] = batch.prefilling
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = F.pad(
position_ids,
(0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
value=1,
)
else:
position_ids = F.pad(
position_ids,
(0, (padded_bs - orig_bs) * input_seq.shape[-1]),
value=1,
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -476,8 +624,6 @@ class FlashVlmCausalLM(FlashCausalLM):
image_grid_thw=batch.image_grid_thw,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
@@ -486,4 +632,6 @@ class FlashVlmCausalLM(FlashCausalLM):
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
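The padding logic above follows one pattern throughout: pad the batch up to the bucketed size so HPU graph input shapes stay static, run the model, then slice the logits back down to the real rows. A minimal, hedged sketch of that pattern (the helper name and the toy values are assumptions for illustration):

```python
import torch
import torch.nn.functional as F


def pad_to_bucketed_batch(input_ids, input_lengths, slots, padded_bs):
    """Pad per-request tensors from orig_bs rows up to padded_bs rows."""
    orig_bs = input_lengths.shape[0]
    if padded_bs == orig_bs:
        return input_ids, input_lengths, slots, orig_bs
    seq_len = input_ids.view(orig_bs, -1).shape[-1]
    pad_rows = padded_bs - orig_bs
    # Dummy rows: length 0, token id 0, slot 0.
    input_lengths = F.pad(input_lengths, (0, pad_rows), value=0)
    input_ids = F.pad(input_ids, (0, pad_rows * seq_len), value=0)
    slots = F.pad(slots, (0, pad_rows * seq_len), value=0)
    return input_ids, input_lengths, slots, orig_bs


# Decode-shaped toy batch: 3 real requests bucketed up to 4.
input_ids = torch.tensor([11, 12, 13])
input_lengths = torch.tensor([5, 7, 2], dtype=torch.int32)
slots = torch.tensor([100, 228, 356])
input_ids, input_lengths, slots, orig_bs = pad_to_bucketed_batch(
    input_ids, input_lengths, slots, padded_bs=4
)
# ... model.forward(...) runs on the padded batch ...
# logits = logits[:orig_bs]  # drop the padded rows again afterwards
```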

View File

@@ -11,7 +11,9 @@ from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.flash_causal_lm import (
prepare_for_decode,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
FlashVlmCausalLM,
@@ -19,6 +21,13 @@ from text_generation_server.models.flash_vlm_causal_lm import (
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
from loguru import logger
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
tracer = trace.get_tracer(__name__)
@@ -196,7 +205,178 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
return batch
def generate_cross_attention_states(
cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
):
if cross_attention_states is None:
return None, None, None
device = cross_attention_states.device
indices_list = []
if prefilling:
for i in image_indices:
indices_list.append(
torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
)
indices = torch.cat(indices_list, dim=0)
else:
indices = image_indices[:]
return indices, seqlen.input_lengths.index_select(0, image_indices)
class FlashMllamaCausalLM(FlashVlmCausalLM):
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
# fetch the last slot of each sequence's final block so the warmup covers block_num blocks
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
)
# Decode warmup: replicate the batch's image metadata across batch_size requests and run one decode step (cu_seqlen_prefill is None below).
image_indices = torch.tensor(batch.image_indices, device=self.device)
image_indices = image_indices.repeat(batch_size)
cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states, image_indices, seqlen, 1, False
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots_tensor,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=indices,
cross_attention_len=cross_attention_len,
)
def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
):
input_ids = torch.zeros(
prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(batch_size)
position_ids = torch.arange(
prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(batch_size)
max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).reshape(batch_size, -1)
slot_acc = []
for i in range(batch_size):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
input_lengths = (
torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
)
cache_lengths_tensor = torch.zeros(
batch_size, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
lm_head_indices = input_lengths - 1
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
image_indices = torch.tensor(batch.image_indices, device=self.device)
image_indices = image_indices.repeat(batch_size)
cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states, image_indices, seqlen, prompt_len, True
)
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=None,
lm_head_indices=lm_head_indices,
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=indices,
cross_attention_len=cross_attention_len,
)
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
def forward(
self,
batch: FlashMllamaCausalLMBatch,
@@ -263,12 +443,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
@@ -281,11 +455,82 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
kwargs["bypass_hpu_graphs"] = batch.prefilling
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
image_indices = torch.tensor(batch.image_indices, device=self.device)
if padded_bs != input_lengths.shape[0]:
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
)
position_ids = F.pad(
position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
)
slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
if cross_attention_states is not None:
cross_attention_states = F.pad(
cross_attention_states,
(0, 0, 0, 0, 0, (padded_bs - orig_bs)),
value=0,
)
if len(image_indices) != 0:
pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
image_indices = torch.cat((image_indices, pad_indices), dim=0)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states,
image_indices,
seqlen,
padded_input_len,
batch.prefilling,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -295,14 +540,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
# TODO list
adapter_data=None,
image_indices=batch.image_indices[:],
cross_attention_states=cross_attention_states,
indices=indices,
cross_attention_len=cross_attention_len,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
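A toy walk-through of the index math in `generate_cross_attention_states` above (the numbers are invented; the prefill arithmetic and the decode shortcut follow the diff):

```python
import torch

# Prefill: each image-bearing request occupies a padded span of pad_seq_len
# tokens in the flattened batch, so its cross-attention indices cover the
# whole span [pad_seq_len * i, pad_seq_len * (i + 1)).
pad_seq_len = 4
image_indices = torch.tensor([0, 2])  # requests 0 and 2 carry an image
prefill_indices = torch.cat(
    [
        torch.arange(pad_seq_len * i, pad_seq_len * (i + 1))
        for i in image_indices.tolist()
    ]
)
# prefill_indices -> tensor([0, 1, 2, 3, 8, 9, 10, 11])

# Decode: one token per request, so the image indices are used directly.
decode_indices = image_indices[:]

# In both branches cross_attention_len is taken from the sequence lengths of
# the image-bearing requests: seqlen.input_lengths.index_select(0, image_indices)
```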

View File

@@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator {
(required_blocks, repeats)
};
let tokens = tokens as usize;
let mut tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
@@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator {
.split_off(self.free_blocks.len() - required_blocks as usize);
if self.is_hpu_device {
blocks.sort();
// need 1 extra slot for the ping-pong optimization
tokens += 1;
}
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
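Illustrative arithmetic only, in Python for consistency with the other examples; the real allocator is the Rust code above, and the block size is an assumed value:

```python
import math

# On the HPU path the allocator requests one extra slot so the ping-pong
# optimization mentioned in the comment above has a spare slot to write to.
BLOCK_SIZE = 128                                  # assumed value for illustration
tokens = 300
required_blocks = math.ceil(tokens / BLOCK_SIZE)  # 3 blocks
tokens_on_hpu = tokens + 1                        # 301 slots requested, not 300
print(required_blocks, tokens_on_hpu)             # 3 301
```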