From fd70ad703e960bd4589c4d04c8ca725d15179d0b Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 25 Mar 2025 22:21:44 -0700
Subject: [PATCH] warmup prefill; remove models that do not use paged
 attention; set block_tables to None since it is not used

Warm up prefill over a fixed set of batch-size and sequence-length
buckets so the corresponding HPU graphs are compiled before serving.
Drop the Idefics model from this backend, since it does not use paged
attention. Pass block_tables=None into the model forward: HPU paged
attention does not read it, and a varying block-table shape would
retrigger HPU graph compilation. Also hoist prepare_for_decode out of
the batch class into a module-level helper.

Signed-off-by: Wang, Yi A
---
 .../text_generation_server/models/__init__.py |  17 --
 .../models/flash_causal_lm.py                 | 209 ++++++++++--------
 .../models/flash_vlm_causal_lm.py             |   2 +-
 .../models/mllama_causal_lm.py                |   2 +-
 .../server/text_generation_server/server.py   |   2 -
 5 files changed, 117 insertions(+), 115 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
index 7dac910e..778b14a1 100644
--- a/backends/gaudi/server/text_generation_server/models/__init__.py
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -92,7 +92,6 @@ try:
     from text_generation_server.models.custom_modeling.flash_phi_modeling import (
         FlashPhiForCausalLM,
     )
-    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
     from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
     from text_generation_server.models.custom_modeling.flash_mllama import (
         FlashMllamaForConditionalGeneration,
@@ -144,7 +143,6 @@ except ImportError as e:

 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
-    __all__.append(IdeficsCausalLM)


 class ModelType(enum.Enum):
@@ -301,12 +299,6 @@ class ModelType(enum.Enum):
         "name": "Gptj",
         "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
     }
-    IDEFICS = {
-        "type": "idefics",
-        "name": "Idefics",
-        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
-        "multimodal": True,
-    }
     MLLAMA = {
         "type": "mllama",
         "name": "Mllama",
@@ -733,15 +725,6 @@ def get_model(
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
-    elif model_type == IDEFICS:
-        return IdeficsCausalLM(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
     elif model_type == QWEN2_VL:
         return FlashVlmCausalLM(
             model_id=model_id,
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 4cdf2628..b26184e4 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -69,6 +69,8 @@ from text_generation_server.utils.import_utils import (

 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
+import itertools
+from vllm_hpu_extension.ops import batch2block, block2batch

 tracer = trace.get_tracer(__name__)

@@ -86,6 +88,78 @@ def get_sliding_windows() -> int:
     return SLIDING_WINDOW


+def prepare_for_decode(
+    dtype, use_contiguous_pa, device, slot, block_tables, batch_size
+):
+    # Prepare values if we need to continue decoding,
+    # needed for HPUPagedAttentionMetadata preparation
+    def flatten(in_list):
+        return list(itertools.chain(*in_list))
+
+    def gather_list(input, indices, v):
+        return [input[i] if i is not None else v for i in indices]
+
+    def pad_list(input, k, v):
+        input_len = len(input)
+        target_len = (input_len + k - 1) // k * k
+        padding = target_len - input_len
+        return input + [v] * padding
+
+    last_block_usage = slot % BLOCK_SIZE + 1
+    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+    block_usage = [
+        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
+        for bt, lbu in zip(block_tables, last_block_usage)
+        if bt
+    ]
+
+    block_list = flatten(block_tables)
+    block_groups = flatten(block_groups)
+    block_usage = flatten(block_usage)
+
+    assert len(block_list) == len(block_groups)
+    assert len(block_list) == len(block_usage)
+    if use_contiguous_pa:
+        block_bucket_size = max(max(block_list) + 1, len(block_list))
+        # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
+        #     block_bucket_size)
+        indices: List[Any]
+        indices = [None] * block_bucket_size
+        for i, bid in enumerate(block_list):
+            indices[bid] = i
+        block_list = gather_list(block_list, indices, 0)
+        block_groups = gather_list(block_groups, indices, -1)
+        block_usage = gather_list(block_usage, indices, 1)
+    else:
+        block_bucket_size = len(block_list)
+        block_list = pad_list(block_list, block_bucket_size, 0)
+        block_groups = pad_list(block_groups, block_bucket_size, -1)
+        block_usage = pad_list(block_usage, block_bucket_size, 1)
+
+    block_list = torch.tensor(block_list, dtype=torch.int, device=device)
+    block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
+    block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
+    block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
+    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+    mask = mask >= block_usage.unsqueeze(-1)
+    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+    ones = torch.ones(
+        (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
+    )
+    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+    block_scales = torch.reciprocal(torch.maximum(ones, sums))
+    return trim_attn_metadata(
+        HPUPagedAttentionMetadata(
+            block_list=block_list,
+            block_groups=block_groups,
+            block_usage=block_usage,
+            block_mapping=block_mapping.to(dtype),
+            attn_bias=attn_bias,
+            block_scales=block_scales,
+        )
+    )
+
+
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
@@ -879,83 +953,18 @@ class FlashCausalLMBatch(Batch):
         )

     def prepare_for_decode(self, dtype, use_contiguous_pa):
-        # Prepare values if we need to continue decoding
-        # need for HPUPagedAttentionMetadata preparation
-        import itertools
-        from vllm_hpu_extension.ops import batch2block, block2batch
-
-        def flatten(in_list):
-            return list(itertools.chain(*in_list))
-
-        def gather_list(input, indices, v):
-            return [input[i] if i is not None else v for i in indices]
-
-        def pad_list(input, k, v):
-            input_len = len(input)
-            target_len = (input_len + k - 1) // k * k
-            padding = target_len - input_len
-            return input + [v] * padding
-
-        device = self.block_tables_tensor.device
-        last_block_usage = self.slots[self.slot_indices] % BLOCK_SIZE + 1
         block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
-        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
-        block_usage = [
-            [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
-            for bt, lbu in zip(block_tables, last_block_usage)
-            if bt
-        ]
-        block_list = flatten(block_tables)
-        block_groups = flatten(block_groups)
-        block_usage = flatten(block_usage)
-        batch = self.input_ids.size(0)
-
-        assert len(block_list) == len(block_groups)
-        assert len(block_list) == len(block_usage)
-        if use_contiguous_pa:
-            block_bucket_size = max(max(block_list) + 1, len(block_list))
-            # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
-            #     block_bucket_size)
-            indices: List[Any]
-            indices = [None] * block_bucket_size
-            for i, bid in enumerate(block_list):
-                indices[bid] = i
-            block_list = gather_list(block_list, indices, 0)
-            block_groups = gather_list(block_groups, indices, -1)
-            block_usage = gather_list(block_usage, indices, 1)
-        else:
-            block_bucket_size = len(block_list)
-            block_list = pad_list(block_list, block_bucket_size, 0)
-            block_groups = pad_list(block_groups, block_bucket_size, -1)
-            block_usage = pad_list(block_usage, block_bucket_size, 1)
-
-        block_list = torch.tensor(block_list, dtype=torch.int, device=device)
-        block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
-        block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
-        block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch)
-        mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(
-            0
-        )
-        mask = mask >= block_usage.unsqueeze(-1)
-        attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
-        ones = torch.ones(
-            (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
-        )
-        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
-        block_scales = torch.reciprocal(torch.maximum(ones, sums))
-        self.hpu_attn_meta = trim_attn_metadata(
-            HPUPagedAttentionMetadata(
-                block_list=block_list,
-                block_groups=block_groups,
-                block_usage=block_usage,
-                block_mapping=block_mapping.to(dtype),
-                attn_bias=attn_bias,
-                block_scales=block_scales,
-            )
+        self.hpu_attn_meta = prepare_for_decode(
+            dtype,
+            use_contiguous_pa,
+            self.block_tables_tensor.device,
+            self.slots[self.slot_indices],
+            block_tables,
+            self.input_ids.size(0),
         )

     def prepare_for_prefill(self):
@@ -1481,32 +1490,44 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        for bs in [1, 2, 4, 8]:
+            for seqlen in [32, 64, 128, 256, 512, 1024]:
+                self.warmup_prefill(seqlen, bs)
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def tunableop_warmup(self, seqlen: int, max_bt: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = torch.zeros(
-            seqlen, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.tensor(
-            [0, seqlen], device=self.device, dtype=torch.int32
-        )
-
+    def warmup_prefill(self, prompt_len: int, bs: int):
+        input_ids = torch.zeros(
+            prompt_len, dtype=torch.int64, device=self.device
+        ).repeat(bs)
+        position_ids = torch.arange(
+            prompt_len, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).repeat(seqlen)
-        block_tables = block_tables.reshape((seqlen, max_bt))
+        ).reshape(bs, -1)
+        slot_acc = []
+        for i in range(bs):
+            slots = []
+            for b in block_tables[i]:
+                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+            slot_acc.extend(slots[:prompt_len])
+        slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+
+        input_lengths = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+        )
+        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
         )
+        lm_head_indices = input_lengths - 1

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=block_tables,
-            seqlen=seqlen,
+            block_tables=None,  # block_tables is not used in HPU paged attention; pass None to avoid shape changes in the HPU graph
             slots=slots,
-            lm_head_indices=None,
+            seqlen=trim_seqlen_metadata(seqlen),
             prefill_cache_indices=None,
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            hpu_attention_meta=None,
         )

     def forward(
@@ -1606,7 +1629,7 @@ class FlashCausalLM(Model):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
+            block_tables=None,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
             prefill_cache_indices=batch.prefill_cache_indices,
@@ -1637,9 +1660,7 @@ class FlashCausalLM(Model):
             batch.prepare_for_prefill()
         else:
             batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)
-
         prefill_logprobs = batch.prefill_next_token_indices is not None
-
         # Update adapter indices for speculative tokens (if present)
         adapter_meta = batch.adapter_meta
         if batch.speculative_ids is not None:
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 7cff7797..48bfce89 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -462,7 +462,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
+            block_tables=None,  # block_tables is not used in HPU paged attention; pass None to avoid shape changes in the HPU graph
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index be67b6ae..4471aab3 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -288,7 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
+            block_tables=None,  # block_tables is not used in HPU paged attention; pass None to avoid shape changes in the HPU graph
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
index 6e470361..5a7d2117 100644
--- a/backends/gaudi/server/text_generation_server/server.py
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -33,13 +33,11 @@ try:
     from text_generation_server.models.flash_vlm_causal_lm import (
         FlashVlmCausalLMBatch,
     )
-    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

     VLM_BATCH_TYPES = {
         PaliGemmaBatch,
         VlmCausalLMBatch,
         FlashVlmCausalLMBatch,
-        IdeficsCausalLMBatch,
         FlashMllamaCausalLMBatch,
     }
 except (ImportError, NotImplementedError):
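
A standalone sketch of the dummy block-table and slot layout that the new
warmup_prefill builds, runnable on CPU to sanity-check the bucketing math.
BLOCK_SIZE = 64 is assumed here purely for illustration, and
build_warmup_layout is a hypothetical helper name; the backend takes the
real BLOCK_SIZE from its paged-attention configuration.

import torch

BLOCK_SIZE = 64  # assumption for illustration; the backend defines its own value


def build_warmup_layout(prompt_len: int, bs: int, device: str = "cpu"):
    # Each sequence gets prompt_len // BLOCK_SIZE + 1 blocks: enough to hold
    # prompt_len tokens, with slack in the last block (mirrors warmup_prefill).
    max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
    block_tables = torch.arange(max_bt, dtype=torch.int32, device=device).reshape(
        bs, -1
    )
    slot_acc = []
    for i in range(bs):
        # Expand each block id into its BLOCK_SIZE slots, then keep only the
        # first prompt_len slots for this sequence.
        slots = []
        for b in block_tables[i]:
            slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
        slot_acc.extend(slots[:prompt_len])
    return block_tables, torch.tensor(slot_acc, dtype=torch.int64, device=device)


bt, slots = build_warmup_layout(prompt_len=32, bs=2)
print(bt.tolist())            # [[0], [1]]: one block per sequence
print(slots[:4].tolist())     # [0, 1, 2, 3]: sequence 0 fills block 0
print(slots[32:36].tolist())  # [64, 65, 66, 67]: sequence 1 fills block 1

Because every warmed-up (bs, seqlen) bucket reuses this same deterministic
layout, the forward shapes seen during warmup match the ones seen while
serving, which is what lets the HPU graphs compiled here be reused.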