From 705cc0b6195c7a7572d85d7a3acff563e9af32d1 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 1 Apr 2025 23:57:07 -0700
Subject: [PATCH] multi-modality warmup

Signed-off-by: Wang, Yi A
---
 .../models/flash_causal_lm.py     |  16 +-
 .../models/flash_vlm_causal_lm.py | 153 +++++++++++++-
 .../models/mllama_causal_lm.py    | 197 +++++++++++++++++-
 3 files changed, 345 insertions(+), 21 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index ed8b658a..48165256 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1487,7 +1487,6 @@ class FlashCausalLM(Model):
 
         if max_input_tokens is None:
             max_input_tokens = max_total_tokens - 1
 
-        del _batch, batch
         self.kv_cache = []
         empty_cache()
@@ -1499,6 +1498,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+
         self.bucketing_ctx = HPUBucketingContext(
             os.getenv("DECODE_MAX_BS", 128),  # self.max_num_seqs, #TODO
             os.getenv("PREFILL_MAX_BS", 16),  # self.max_num_prefill_seqs, #TODO
@@ -1506,6 +1506,17 @@ class FlashCausalLM(Model):
             num_blocks * BLOCK_SIZE,
         )
         self.bucketing_ctx.num_hpu_blocks = num_blocks
+        if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
+            logger.info("skip warmup hpu graph, not recommended")
+            del _batch, batch
+            return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+
+        self.warmup_hpu_graph(batch)
+        del _batch, batch
+
+        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+
+    def warmup_hpu_graph(self, batch):
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
         for i, (batch_size, seq_len) in enumerate(
@@ -1513,14 +1524,13 @@ class FlashCausalLM(Model):
         ):
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size)
-        self.bucketing_ctx.generate_decode_buckets(num_blocks)
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num)
         synchronize(self.device)
-        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
     def warmup_prefill(self, prompt_len: int, bs: int):
         logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
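Note: the hunks above move the bucket sweep out of warmup() into a dedicated warmup_hpu_graph() and add a SKIP_WARMUP_GRAPH escape hatch. A rough standalone sketch of the sweep, for orientation only (sketch_warmup and its arguments are illustrative stand-ins, not code from this patch):

import os

def sketch_warmup(model, bucketing_ctx, warmup_times=3):
    # Skipping warmup saves startup time but pushes HPU graph compilation into
    # the first real requests, which is why the patch logs that skipping is not
    # recommended.
    if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
        return
    # Replay every prompt bucket a few times, largest first, so the biggest
    # graphs are compiled before serving starts.
    bucketing_ctx.generate_prompt_buckets()
    for batch_size, seq_len in reversed(bucketing_ctx.prompt_buckets):
        for _ in range(warmup_times):
            model.warmup_prefill(seq_len, batch_size)
    # Decode buckets are derived from the number of KV-cache blocks reserved earlier.
    bucketing_ctx.generate_decode_buckets(bucketing_ctx.num_hpu_blocks)
    for batch_size, block_num in reversed(bucketing_ctx.decode_buckets):
        for _ in range(warmup_times):
            model.warmup_decode(batch_size, block_num)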
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 208ab358..2f9de99f 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    prepare_for_decode,
 )
-from text_generation_server.models.globals import PREFIX_CACHING
+from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
 import habana_frameworks.torch as htorch
+from text_generation_server.utils.import_utils import (
+    synchronize,
+)
+import torch.nn.functional as F
 
 tracer = trace.get_tracer(__name__)
 
@@ -375,6 +380,80 @@ class FlashVlmCausalLM(FlashCausalLM):
     def max_past(self) -> Optional[int]:
         return getattr(self.model.text_model, "max_past", None)
 
+    def warmup_decode(
+        self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
+    ):
+        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
+        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        if batch.position_ids is not None and batch.position_ids.dim() == 2:
+            # qwen2_vl and qwen2_5_vl case
+            position_ids = position_ids.unsqueeze(-1).repeat(
+                (1, batch.position_ids.shape[-1])
+            )
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
+        slots = []
+        start_idx = 0
+
+        # use the last slot of each sequence's final block so the warmup covers block_num blocks
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
+        )
+        # no `cu_seqlen_prefill` here: this exercises the decode (paged attention) path
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            lm_head_indices=None,
+            adapter_data=None,
+            hpu_attention_meta=hpu_attention_meta,
+        )
+
+    def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        warmup_times = 3
+        # only warm up decode; for prefill the image pixel size may change, which makes the warmup useless
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+        for i, (batch_size, block_num) in enumerate(
+            reversed(self.bucketing_ctx.decode_buckets)
+        ):
+            for index in range(warmup_times):
+                self.warmup_decode(batch_size, block_num, batch)
+        synchronize(self.device)
+
     def forward(
         self,
         batch: FlashVlmCausalLMBatch,
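Note: the slot/block arithmetic in warmup_decode above is easier to follow outside diff form. A minimal sketch of the same layout (fake_decode_layout is a hypothetical helper and BLOCK_SIZE is assumed to be 128 for the example; the real value comes from text_generation_server.models.globals):

BLOCK_SIZE = 128  # assumed for the example

def fake_decode_layout(batch_size: int, block_num: int):
    # Spread block_num KV-cache blocks over batch_size sequences; the first
    # sequence absorbs the remainder so every block is covered.
    blocks = [block_num // batch_size] * batch_size
    blocks[0] += block_num % batch_size

    block_tables, slots, past_len, start = [], [], [], 0
    for n in blocks:
        table = list(range(start, start + n))
        block_tables.append(table)
        # Decode writes exactly one new token, so warm up with the very last
        # slot of each sequence's final block.
        slots.append(table[-1] * BLOCK_SIZE + BLOCK_SIZE - 1)
        past_len.append(n * BLOCK_SIZE - 1)
        start += n
    return block_tables, slots, past_len

# fake_decode_layout(2, 5) -> block_tables [[0, 1, 2], [3, 4]], slots [383, 639]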
@@ -450,17 +529,75 @@ class FlashVlmCausalLM(FlashCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = batch.prefilling
 
-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
+        if self.bucketing_ctx is not None:
+            if batch.prefilling:
+                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
+                    input_lengths.shape[0]
+                )
+            else:
+                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    input_lengths.shape[0]
+                )
+        else:
+            padded_bs = input_lengths.shape[0]
+        if padded_bs != input_lengths.shape[0]:
+            orig_bs = input_lengths.shape[0]
+            padded_input_lengths = F.pad(
+                input_lengths,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            padded_cache_lengths_tensor = F.pad(
+                cache_lengths_tensor,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            if cu_seqlen_prefill is not None:
+                cu_seqlen_prefill = torch.zeros(
+                    padded_bs + 1, device=self.device, dtype=torch.int32
+                )
+                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
+            seqlen = Seqlen(
+                input_lengths=padded_input_lengths,
+                cache_lengths=padded_cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+            input_seq = input_ids.view(orig_bs, -1)
+            input_ids = F.pad(
+                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+            )
+            if position_ids.dim() == 2:
+                # qwen2_vl and qwen2_5_vl case
+                position_ids = F.pad(
+                    position_ids,
+                    (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
+                    value=1,
+                )
+            else:
+                position_ids = F.pad(
+                    position_ids,
+                    (0, (padded_bs - orig_bs) * input_seq.shape[-1]),
+                    value=1,
+                )
+            slots = F.pad(
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+            )
+            if lm_head_indices is not None:
+                lm_head_indices = F.pad(
+                    lm_head_indices, (0, padded_bs - orig_bs), value=0
+                )
+        else:
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -476,8 +613,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             image_grid_thw=batch.image_grid_thw,
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
         if batch.pixel_attention_mask is not None:
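Note: the padding branch added to forward above is what lets live batches land on the warmed-up graph shapes: the batch is padded up to the nearest bucketed batch size before calling model.forward. A simplified sketch of the scheme for flat 1-D inputs (pad_to_bucket_bs is a hypothetical helper, not part of the patch):

import torch.nn.functional as F

def pad_to_bucket_bs(input_ids, position_ids, slots, input_lengths, padded_bs):
    # Mirrors the pad values used above: 0 for input ids and lengths, 1 for
    # position ids, -1 for slots, so the filler rows are easy to recognize.
    orig_bs = input_lengths.shape[0]
    seq_len = input_ids.shape[0] // orig_bs
    extra_tokens = (padded_bs - orig_bs) * seq_len
    input_ids = F.pad(input_ids, (0, extra_tokens), value=0)
    position_ids = F.pad(position_ids, (0, extra_tokens), value=1)
    slots = F.pad(slots, (0, extra_tokens), value=-1)
    input_lengths = F.pad(input_lengths, (0, padded_bs - orig_bs), value=0)
    return input_ids, position_ids, slots, input_lengths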
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index e034ed49..55d80ca5 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -11,7 +11,9 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
-
+from text_generation_server.models.flash_causal_lm import (
+    prepare_for_decode,
+)
 from text_generation_server.models.flash_vlm_causal_lm import (
     FlashVlmCausalLMBatch,
     FlashVlmCausalLM,
@@ -19,6 +21,12 @@ from text_generation_server.models.flash_vlm_causal_lm import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
 import habana_frameworks.torch as htorch
+from loguru import logger
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.import_utils import (
+    synchronize,
+)
+import torch.nn.functional as F
 
 tracer = trace.get_tracer(__name__)
 
@@ -197,6 +205,131 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
 
 
 class FlashMllamaCausalLM(FlashVlmCausalLM):
+    def warmup_decode(
+        self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
+    ):
+        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
+        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
+        slots = []
+        start_idx = 0
+
+        # use the last slot of each sequence's final block so the warmup covers block_num blocks
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
+        )
+        # no `cu_seqlen_prefill` here: this exercises the decode (paged attention) path
+ self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=None, + adapter_data=None, + hpu_attention_meta=hpu_attention_meta, + cross_attention_states=batch.cross_attention_states, + image_indices=batch.image_indices[:], + ) + + def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch): + logger.info(f"warmup prefill seq {prompt_len} bs {bs}") + input_ids = torch.zeros( + prompt_len, dtype=torch.int64, device=self.device + ).repeat(bs) + position_ids = torch.arange( + prompt_len, dtype=torch.int32, device=self.device + ).repeat(bs) + max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).reshape(bs, -1) + slot_acc = [] + for i in range(bs): + slots = [] + for b in block_tables[i]: + slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) + slot_acc.extend(slots[:prompt_len]) + slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) + + input_lengths = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + lm_head_indices = input_lengths - 1 + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=lm_head_indices, + cross_attention_states=batch.cross_attention_states, + adapter_data=None, + hpu_attention_meta=None, + image_indices=batch.image_indices[:], + ) + + def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): + warmup_times = 3 + self.bucketing_ctx.generate_prompt_buckets() + for i, (batch_size, seq_len) in enumerate( + reversed(self.bucketing_ctx.prompt_buckets) + ): + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + def forward( self, batch: FlashMllamaCausalLMBatch, @@ -263,12 +396,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): # This makes sure the max_s for the decode pass is correct. 
max_s = min(self.max_past(), max_s) - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - if batch.pixel_values is not None: cross_attention_states = self.model.vision_forward( pixel_values=batch.pixel_values, @@ -286,6 +413,60 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + orig_bs = input_lengths.shape[0] + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + position_ids = F.pad( + position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -301,8 +482,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): image_indices=batch.image_indices[:], **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None return logits, speculative_logits
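Note: for the prefill warmup in the mllama diff, each warmed-up sequence reserves prompt_len // BLOCK_SIZE + 1 blocks and uses only the first prompt_len slots of that reservation. A small sketch of that slot layout (fake_prefill_slots is illustrative; BLOCK_SIZE is assumed to be 128 here):

BLOCK_SIZE = 128  # assumed for the example

def fake_prefill_slots(prompt_len: int, bs: int):
    # Each sequence gets a contiguous run of blocks and fills only its first
    # prompt_len slots, mirroring warmup_prefill above.
    blocks_per_seq = prompt_len // BLOCK_SIZE + 1
    slot_acc = []
    for i in range(bs):
        first_block = i * blocks_per_seq
        slots = []
        for b in range(first_block, first_block + blocks_per_seq):
            slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
        slot_acc.extend(slots[:prompt_len])
    return slot_acc

# fake_prefill_slots(3, 2) -> [0, 1, 2, 128, 129, 130]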