multi-modality warmup

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-04-01 23:57:07 -07:00
Commit: 705cc0b619 (parent 9d85ac9485)
3 changed files with 345 additions and 21 deletions

File 1 of 3 (FlashCausalLM):

@@ -1487,7 +1487,6 @@ class FlashCausalLM(Model):
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
del _batch, batch
self.kv_cache = []
empty_cache()
@@ -1499,6 +1498,7 @@ class FlashCausalLM(Model):
self.kv_cache_dtype,
self.device,
)
self.bucketing_ctx = HPUBucketingContext(
os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO
os.getenv("PREFILL_MAX_BS", 16), # self.max_num_prefill_seqs, #TODO
@@ -1506,6 +1506,17 @@ class FlashCausalLM(Model):
num_blocks * BLOCK_SIZE,
)
self.bucketing_ctx.num_hpu_blocks = num_blocks
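# optional escape hatch: without warmup, HPU graph compilation happens on the first real requests instead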
if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
logger.info("skip warmup hpu graph, not recommmended")
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
self.warmup_hpu_graph(batch)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def warmup_hpu_graph(self, batch):
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
@@ -1513,14 +1524,13 @@ class FlashCausalLM(Model):
):
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size)
self.bucketing_ctx.generate_decode_buckets(num_blocks)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num)
synchronize(self.device)
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def warmup_prefill(self, prompt_len: int, bs: int):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")

File 2 of 3 (FlashVlmCausalLM):

@@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
prepare_for_decode,
)
from text_generation_server.models.globals import PREFIX_CACHING
from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
tracer = trace.get_tracer(__name__)
@@ -375,6 +380,80 @@ class FlashVlmCausalLM(FlashCausalLM):
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
if batch.position_ids is not None and batch.position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = position_ids.unsqueeze(-1).repeat(
(1, batch.position_ids.shape[-1])
)
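# split block_num KV-cache blocks as evenly as possible across the batch; the first sequence takes any remainder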
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
# give each sequence the last slot of its final block so that block_num blocks in total are exercised
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
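# decode advances one token per sequence, so every query length is 1 and the rest of the sequence lives in the KV cache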
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
)
# Decode-only warmup pass: cu_seqlen_prefill is None so the paged-attention decode path is exercised.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
warmup_times = 3
# only warm up decode; for prefill, the image pixel size may change, which would make the warmup useless
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
def forward(
self,
batch: FlashVlmCausalLMBatch,
@@ -450,17 +529,75 @@ class FlashVlmCausalLM(FlashCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
kwargs["bypass_hpu_graphs"] = batch.prefilling
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
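# pad the live batch up to the nearest bucket size so a graph compiled during warmup can be reused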
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = F.pad(
position_ids,
(0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
value=1,
)
else:
position_ids = F.pad(
position_ids,
(0, (padded_bs - orig_bs) * input_seq.shape[-1]),
value=1,
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -476,8 +613,6 @@ class FlashVlmCausalLM(FlashCausalLM):
image_grid_thw=batch.image_grid_thw,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:

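A note on the padding arithmetic in the forward above: input_ids arrives flattened as (batch * seq_len,) tokens, so the pad amount is (padded_bs - orig_bs) * seq_len. A toy check of that shape math, with assumed sizes (the real padded_bs comes from the bucketing context, not from this sketch):

import torch
import torch.nn.functional as F

orig_bs, seq_len, padded_bs = 3, 5, 4                # assumed toy values, not real bucket sizes
input_ids = torch.arange(orig_bs * seq_len)          # flattened (batch * seq_len,), as in forward
input_seq = input_ids.view(orig_bs, -1)              # per-sequence view, used only to read seq_len
input_ids = F.pad(input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0)
assert input_ids.shape == (padded_bs * seq_len,)     # 20 tokens: 15 real + 5 padding
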
File 3 of 3 (FlashMllamaCausalLM):

@@ -11,7 +11,9 @@ from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.flash_causal_lm import (
prepare_for_decode,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
FlashVlmCausalLM,
@@ -19,6 +21,12 @@ from text_generation_server.models.flash_vlm_causal_lm import (
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
from loguru import logger
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
tracer = trace.get_tracer(__name__)
@@ -197,6 +205,131 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
class FlashMllamaCausalLM(FlashVlmCausalLM):
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
# give each sequence the last slot of its final block so that block_num blocks in total are exercised
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
)
# Decode-only warmup pass: cu_seqlen_prefill is None so the paged-attention decode path is exercised.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
cross_attention_states=batch.cross_attention_states,
image_indices=batch.image_indices[:],
)
def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
input_ids = torch.zeros(
prompt_len, dtype=torch.int64, device=self.device
).repeat(bs)
position_ids = torch.arange(
prompt_len, dtype=torch.int32, device=self.device
).repeat(bs)
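# upper bound on block-table entries: at most prompt_len // BLOCK_SIZE + 1 blocks per sequence, times the batch size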
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).reshape(bs, -1)
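# enumerate the slot ids behind each sequence's blocks and keep the first prompt_len slots per sequence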
slot_acc = []
for i in range(bs):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
input_lengths = (
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
)
cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
lm_head_indices = input_lengths - 1
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=lm_head_indices,
cross_attention_states=batch.cross_attention_states,
adapter_data=None,
hpu_attention_meta=None,
image_indices=batch.image_indices[:],
)
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
def forward(
self,
batch: FlashMllamaCausalLMBatch,
@@ -263,12 +396,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
@@ -286,6 +413,60 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
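# same bucket padding as FlashVlmCausalLM.forward: round the batch up to a warmed-up shape before running the graph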
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
position_ids = F.pad(
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -301,8 +482,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
image_indices=batch.image_indices[:],
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits
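
For reference, the block-table and slot layout that both warmup_decode implementations above build can be checked in isolation. A minimal sketch with assumed sizes (BLOCK_SIZE is 16 here only for illustration; the real value comes from text_generation_server.models.globals):

BLOCK_SIZE = 16                                   # assumed for illustration
batch_size, block_num = 2, 5                      # assumed warmup bucket

blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size               # -> [3, 2]
block_tables, slots, past_len, start_idx = [], [], [], 0
for i in range(batch_size):
    block_array = list(range(start_idx, start_idx + blocks[i]))
    slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)   # last slot of the sequence's last block
    block_tables.append(block_array)
    past_len.append(blocks[i] * BLOCK_SIZE - 1)                   # cache is full except the slot being decoded into
    start_idx += blocks[i]

print(block_tables)  # [[0, 1, 2], [3, 4]]
print(slots)         # [47, 79]
print(past_len)      # [47, 31]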