diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index 8ec9fb46..34c77040 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata: block_list: Optional[torch.Tensor] block_mapping: Optional[torch.Tensor] block_usage: Optional[torch.Tensor] - block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] attn_bias: Optional[torch.Tensor] @@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: "block_list", "block_mapping", "block_usage", - "block_scales", "block_groups", "attn_bias", ], diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index f34e93ab..1d73dcb3 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -74,7 +74,6 @@ def paged_attention( block_list=hpu_attention_meta.block_list, block_mapping=hpu_attention_meta.block_mapping, block_bias=hpu_attention_meta.attn_bias, - block_scales=hpu_attention_meta.block_scales, block_groups=hpu_attention_meta.block_groups, scale=softmax_scale, matmul_qk_op=Matmul(), diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index 216642e0..421a0a65 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module): # bsz, q_len, _ = hidden_states.size() ( cross_attention_states, - cu_seqlen_q, - cu_seqlen_k, + cross_attention_len, indices, ) = cross_attention_states - bs = cu_seqlen_q.size(0) - 1 + bs = cross_attention_len.size(0) query_states = self.q_proj(hidden_states) query_states = query_states.view(bs, -1, self.num_heads, self.head_size) query_states = self.q_norm(query_states) @@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module): indices = cross_attention_states[-1] out_hidden_states = hidden_states[:] - if len(indices) > 0: - assert max(indices) < hidden_states.shape[0] hidden_states = hidden_states[indices] residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor], adapter_data: Optional[torch.Tensor] = None, - # XXX: Putting these as optional so that the cuda warmup calls can go through. 
cross_attention_states: Optional[torch.Tensor] = None, - image_indices=None, + indices=None, + cross_attention_len: Optional[torch.Tensor] = None, ): if cross_attention_states is not None: - seqlen_q = len(image_indices) - n_images = cross_attention_states.shape[0] - seqlen_k = cross_attention_states.shape[1] - device = cross_attention_states.device - if cu_seqlen_prefill is not None: - offset = 0 - cu_q = [] - indices = [] - for index in image_indices: - cu_q.append(offset) - length = seqlen.input_lengths[index].item() - assert index < seqlen.cu_seqlen_q.shape[0] - input_ids_offset = seqlen.cu_seqlen_q[index] - indices.extend(range(input_ids_offset, input_ids_offset + length)) - offset += length - cu_q.append(offset) - cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32) - - assert max(indices) < input_ids.shape[0] - - cu_seqlen_k = ( - torch.arange( - n_images + 1, - device=device, - dtype=torch.int32, - ) - * seqlen_k - ) - else: - cu_seqlen_q = torch.arange( - seqlen_q + 1, device=device, dtype=torch.int32 - ) - seqlen_k = cross_attention_states.shape[1] - n_images = cross_attention_states.shape[0] - cu_seqlen_k = ( - torch.arange( - n_images + 1, - device=device, - dtype=torch.int32, - ) - * seqlen_k - ) - indices = image_indices[:] - cross_attention_states = ( cross_attention_states, - cu_seqlen_q, - cu_seqlen_k, + cross_attention_len, indices, ) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index a4d58596..ecedd4aa 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -70,7 +70,7 @@ from text_generation_server.utils.import_utils import ( import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools -from vllm_hpu_extension.ops import batch2block, block2batch +from vllm_hpu_extension.bucketing import HPUBucketingContext tracer = trace.get_tracer(__name__) @@ -89,7 +89,7 @@ def get_sliding_windows() -> int: def prepare_for_decode( - dtype, use_contiguous_pa, device, slot, block_tables, batch_size + dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx ): # Prepare values if we need to continue decoding # need for HPUPagedAttentionMetadata preparation @@ -105,7 +105,7 @@ def prepare_for_decode( padding = target_len - input_len return input + [v] * padding - last_block_usage = slot % BLOCK_SIZE + 1 + last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots] block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] block_usage = [ [BLOCK_SIZE] * (len(bt) - 1) + [lbu] @@ -120,8 +120,10 @@ def prepare_for_decode( assert len(block_list) == len(block_usage) if use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) - # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( - # block_bucket_size) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -131,6 +133,10 @@ def prepare_for_decode( block_usage = gather_list(block_usage, indices, 1) else: block_bucket_size = len(block_list) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) block_list = pad_list(block_list, block_bucket_size, 0) block_groups = 
pad_list(block_groups, block_bucket_size, -1) block_usage = pad_list(block_usage, block_bucket_size, 1) @@ -142,11 +148,6 @@ def prepare_for_decode( mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) mask = mask >= block_usage.unsqueeze(-1) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) - ones = torch.ones( - (block_mapping.size(0),), device=device, dtype=block_mapping.dtype - ) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) return trim_attn_metadata( HPUPagedAttentionMetadata( block_list=block_list, @@ -154,7 +155,6 @@ def prepare_for_decode( block_usage=block_usage, block_mapping=block_mapping.to(dtype), attn_bias=attn_bias, - block_scales=block_scales, ) ) @@ -246,6 +246,9 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta: Optional[HPUPagedAttentionMetadata] + next_token_logits: Optional[torch.Tensor] + speculative_logits: Optional[torch.Tensor] + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, @@ -321,6 +324,8 @@ class FlashCausalLMBatch(Batch): ### Deactivating it by default seems like the best course. if not REQUEST_LOGPROBS: r.prefill_logprobs = False + else: + assert False, "prefill_logprobs not supported yet" # request id -> idx in list mapping requests_idx_mapping[r.id] = i @@ -481,6 +486,8 @@ class FlashCausalLMBatch(Batch): input_lengths_tensor=None, adapter_meta=None, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) @classmethod @@ -608,6 +615,12 @@ class FlashCausalLMBatch(Batch): max_slots = max(max_slots, slot_length) all_input_ids_tensor = self.all_input_ids_tensor[indices] + next_token_logits = self.next_token_logits[indices] + speculative_logits = ( + self.speculative_logits[indices] + if self.speculative_logits is not None + else None + ) block_tables_tensor = self.block_tables_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] @@ -689,6 +702,8 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, + next_token_logits=next_token_logits, + speculative_logits=speculative_logits, ) @classmethod @@ -816,8 +831,11 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - # Copy tensors (GPU) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + # Copy tensors (HPU) + index = torch.tensor( + list(range(start_index, end_index)), device=batch.input_ids.device + ) + top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -825,7 +843,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] - prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor + prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor) slots_start_index = cumulative_slots slots_end_index = cumulative_slots + len(batch.slots) @@ -835,15 +853,13 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = ( - batch.slot_indices + cumulative_slots 
+ input_ids.index_copy_(0, index, batch.input_ids) + position_ids.index_copy_(0, index, batch.position_ids) + slot_indices.index_copy_( + 0, index, batch.slot_indices + cumulative_slots ) - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor - cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor - - # Copy over adapter indices + input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor) + cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor) adapter_start_index = cumulative_adapter_indices_size adapter_end_index = ( cumulative_adapter_indices_size @@ -949,24 +965,38 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) - def prepare_for_decode(self, dtype, use_contiguous_pa): - block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 + def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): + block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths] block_tables = [] for i, bt in enumerate(self.block_tables): block_tables.append(bt[0 : block_num[i]]) + if bucketing_ctx is not None: + padded_bs = bucketing_ctx.get_padded_decode_batch_size( + self.input_ids.shape[0] + ) + else: + padded_bs = self.input_ids.shape[0] + slots = self.slots[self.slot_indices] + extra_pad = padded_bs - self.input_ids.shape[0] + if extra_pad != 0: + slots = F.pad(slots, (0, extra_pad), value=0) + block_tables.extend([[0]] * extra_pad) self.hpu_attn_meta = prepare_for_decode( dtype, use_contiguous_pa, self.block_tables_tensor.device, - self.slots[self.slot_indices], + slots.cpu(), block_tables, - self.input_ids.size(0), + padded_bs, + bucketing_ctx, ) - def prepare_for_prefill(self): + def prepare_for_prefill(self, max_padded_input_len): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything @@ -980,7 +1010,7 @@ class FlashCausalLMBatch(Batch): # the right logit position input_ids_padded_length = [] # need extra pad to match warmup seq - extra_pad = 0 + extra_pad = max_padded_input_len - self.max_input_length if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1355,9 +1385,9 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - + self.bucketing_ctx = None if htorch.utils.internal.is_lazy(): - htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False) + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) environment.set_model_config(self.config) self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" @@ -1462,12 +1492,11 @@ class FlashCausalLM(Model): log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") if max_total_tokens is None: - max_total_tokens = sum(batch.cache_lengths) + max_total_tokens = sum(batch.input_lengths) if max_input_tokens is None: max_input_tokens = max_total_tokens - 1 - del _batch, batch self.kv_cache = [] empty_cache() @@ -1479,32 +1508,77 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) + + self.bucketing_ctx = HPUBucketingContext( + os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO + os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO + BLOCK_SIZE, + num_blocks * BLOCK_SIZE, + False, + ) + self.bucketing_ctx.num_hpu_blocks = num_blocks 
+ if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": + logger.info("skip warmup hpu graph, not recommended") + del _batch, batch + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens + + self.warmup_hpu_graph(batch) + del _batch, batch + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - def warmup_prefill(self, prompt_len: int, bs: int): + def warmup_hpu_graph(self, batch): + warmup_times = 3 + self.bucketing_ctx.generate_prompt_buckets() + for i, (batch_size, seq_len) in enumerate( + reversed(self.bucketing_ctx.prompt_buckets) + ): + log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + if batch_size > block_num: + continue + log_master( + logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + ) + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + + def warmup_prefill( + self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch + ): input_ids = torch.zeros( - prompt_len, dtype=torch.int64, device=self.device - ).repeat(bs) + prompt_len, dtype=batch.input_ids.dtype, device=self.device + ).repeat(batch_size) position_ids = torch.arange( - prompt_len, dtype=torch.int32, device=self.device - ).repeat(bs) - max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + prompt_len, dtype=batch.position_ids.dtype, device=self.device + ).repeat(batch_size) + max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size block_tables = torch.arange( max_bt, dtype=torch.int32, device=self.device - ).reshape(bs, -1) + ).reshape(batch_size, -1) slot_acc = [] - for i in range(bs): + for i in range(batch_size): slots = [] for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slot_acc.extend(slots[:prompt_len]) - slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) + slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) input_lengths = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len + torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 ) - cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) - cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( @@ -1527,25 +1601,34 @@ class FlashCausalLM(Model): hpu_attention_meta=None, ) - def warmup_decode(self, bs: int, block_num: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.arange(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - start=1, end=block_num + 1, dtype=torch.int32, device=self.device - ).reshape(bs, -1) - slots = [] - past_len = ( - len(block_tables[0]) * BLOCK_SIZE - 1 - ) # for decode, we only need to pass the past token - # fetch the last blocked to warmup block num - for i in range(bs): - slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) - slots = torch.tensor(slots, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs,
dtype=torch.int32, device=self.device) - cache_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * past_len + def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch): + input_ids = torch.zeros( + batch_size, dtype=batch.input_ids.dtype, device=self.device + ) + position_ids = torch.arange( + batch_size, dtype=batch.position_ids.dtype, device=self.device + ) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] + slots = [] + start_idx = 0 + + # fetch the last blocked to warmup block num + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 ) - cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( @@ -1553,27 +1636,24 @@ class FlashCausalLM(Model): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) - block_num = cache_lengths_tensor // BLOCK_SIZE + 1 - block_tables_valid = [] - for i, bt in enumerate(block_tables.tolist()): - block_tables_valid.append(bt[0 : block_num[i]]) hpu_attention_meta = prepare_for_decode( self.dtype, self.use_contiguous_pa, self.device, slots, - block_tables_valid, - bs, + block_tables, + batch_size, + bucketing_ctx=None, ) - + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots, + slots=slots_tensor, seqlen=trim_seqlen_metadata(seqlen), lm_head_indices=None, adapter_data=None, @@ -1651,19 +1731,68 @@ class FlashCausalLM(Model): # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. 
max_s = min(self.max_past(), max_s) - - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - kwargs = {} - if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + position_ids = F.pad( + position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = batch.prefilling + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1677,12 +1806,174 @@ class FlashCausalLM(Model): hpu_attention_meta=batch.hpu_attn_meta, **kwargs, ) - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) @tracer.start_as_current_span("generate_token") def generate_token( self, batches: List[FlashCausalLMBatch] ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. 
Collect next token ids of any previously started generations + prev_batches = [] + requests_to_generate = [] + for batch_id, batch in enumerate(batches): + if batch.next_token_logits is not None: + prefill = batch.prefilling + if batch.prefilling: + batch.prefilling = False + batch.prefilling_mask = [False] * len(batch) + + speculate = get_speculate() + ( + next_input_ids, + next_token_logprobs, + logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_current_length], + batch.next_token_logits, + speculate, + batch.speculative_ids, + batch.speculative_logits, + ) + + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill: + indices = batch.cu_seqlen_prefill[1:] - 1 + # pad in left + if batch.prefill_cache_indices is not None: + batch.position_ids = batch.position_ids[ + batch.prefill_cache_indices + ][indices] + else: + batch.position_ids = batch.position_ids[indices] + + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = ( + batch.adapter_meta.adapter_indices[indices] + ) + # For each member of the batch + # Cumulative length + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + if batch.speculative_logits is not None: + for i in range(len(batch)): + batch.all_input_ids_tensor[ + i, + batch.cache_lengths[i] + + batch.input_lengths[i] : batch.cache_lengths[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] + else: + index = batch.cache_lengths_tensor + batch.input_lengths_tensor + batch_idx = torch.arange( + 0, + batch.all_input_ids_tensor.shape[0], + dtype=torch.long, + device=batch.input_lengths_tensor.device, + ) + batch.all_input_ids_tensor.index_put_( + (batch_idx, index.long()), next_input_ids + ) + + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.speculative_ids = speculative_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += ( + batch.input_lengths_tensor + accepted_ids - 1 + ) + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) + batch.slot_indices += accepted_ids + + # Does a HPU <-> CPU sync internally + if prefill: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments( + batch.adapter_meta.adapter_indices + ) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) + prev_batches.append( + { + "next_token_ids": next_input_ids, + "next_token_logprobs": next_token_logprobs, + "accepted_ids": accepted_ids, + } + ) + idx = len(prev_batches) - 1 + if batch.speculative_logits is not None: + accepted_ids_cpu = accepted_ids.cpu() + + for req_idx, req in enumerate(batch.requests): + new_input_length = 1 + if batch.speculative_logits is not None: + new_cache_length = ( + batch.cache_lengths[req_idx] + + batch.input_lengths[req_idx] + + accepted_ids_cpu[req_idx] + - 1 + ) + else: + new_cache_length = ( + batch.cache_lengths[req_idx] + 
batch.input_lengths[req_idx] + ) + batch.cache_lengths[req_idx] = new_cache_length + batch.max_input_length = max( + batch.max_input_length, new_input_length + ) + batch.input_lengths[req_idx] = new_input_length + current_length = new_cache_length + new_input_length + batch.max_current_length = max( + batch.max_current_length, current_length + ) + + requests_to_generate.append( + { + "idx": idx, + "request_id": req.id, + "prefix_offset": batch.prefix_offsets[req_idx], + "read_offset": batch.read_offsets[req_idx], + "stopping_criteria": batch.stopping_criterias[req_idx], + "all_input_ids": batch.all_input_ids[req_idx], + "do_sample": batch.next_token_chooser.do_sample[req_idx], + "seed": batch.next_token_chooser.seeds[req_idx], + "top_n_tokens": batch.top_n_tokens[req_idx], + "top_token_ids": batch_top_token_ids[req_idx], + "top_token_logprobs": batch_top_token_logprobs[req_idx], + } + ) + if prefill: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None + batch.next_token_logits = None + batch.speculative_ids = None + + htorch.core.mark_step() + # Stage 2. Prepare new batch for speculative scheduling if len(batches) > 1: batch = self.batch_type.concatenate(batches) else: @@ -1690,9 +1981,16 @@ class FlashCausalLM(Model): start = time.time_ns() prefill = batch.prefilling if prefill: - batch.prepare_for_prefill() + if self.bucketing_ctx is not None: + batch.prepare_for_prefill( + self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length) + ) + else: + batch.prepare_for_prefill(batch.max_input_length) else: - batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) + batch.prepare_for_decode( + self.dtype, self.use_contiguous_pa, self.bucketing_ctx + ) prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta @@ -1724,7 +2022,7 @@ class FlashCausalLM(Model): out, speculative_logits = self.forward(batch, adapter_data) if prefill: - next_token_logits = ( + batch.next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) if speculative_logits is not None: @@ -1733,413 +2031,144 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: prefill_logprobs = None - next_token_logits = out + batch.next_token_logits = out + batch.speculative_logits = speculative_logits - finished_prefilling = True - next_chunk_lengths = [] - current_prefilling_mask = batch.prefilling_mask - if prefill: - finished_prefilling = True - next_prefilling_mask = [False] * len(batch) - - batch.prefilling = not finished_prefilling - batch.prefilling_mask = next_prefilling_mask - - speculate = get_speculate() - ( - next_input_ids, - next_token_logprobs, - logprobs, - accepted_ids, - speculative_ids, - ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], - next_token_logits, - speculate, - batch.speculative_ids, - speculative_logits, - ) - - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids - ) - - # 
Since we are done prefilling, all the tensors that were concatenating values for all the requests - # instantly become of shape [BATCH_SIZE] - if prefill and finished_prefilling: - indices = batch.cu_seqlen_prefill[1:] - 1 - # pad in left - if batch.prefill_cache_indices is not None: - batch.position_ids = batch.position_ids[batch.prefill_cache_indices][ - indices - ] - else: - batch.position_ids = batch.position_ids[indices] - - batch.slot_indices = batch.slot_indices[indices] - batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ - indices - ] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.all_input_ids, - accepted_ids, - current_prefilling_mask, - batch.prefilling_mask, - ) - - # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a HPU <-> CPU sync - # It is faster if we delay this sync for the maximum amount of time - - # For each member of the batch - # Cumulative length - cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) - torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) - cumulative_length = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - all_input_ids, - n_accepted_ids, - request_was_prefilling, - request_is_prefilling, - ) in enumerate(iterator): - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[ - i, cache_length + 1 : cache_length + input_length + 1 - ] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - # If the device does not support triton, we copy one by one - if not request_is_prefilling: - # Only save tokens if we are done prefilling for this request - batch.all_input_ids_tensor[ - i, - batch.cache_lengths_tensor[i] - + batch.input_lengths[i] : batch.cache_lengths_tensor[i] - + batch.input_lengths[i] - + accepted_ids[i], - ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] - cumulative_length += input_length - - # Update values - # These values can be updated without a HPU -> CPU sync - if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] - batch.speculative_ids = speculative_ids - if batch.position_ids.dim() == 2: - # Qwen2_vl case: - batch.position_ids += accepted_ids.unsqueeze(-1) - else: - batch.position_ids += accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 - batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) - batch.slot_indices += accepted_ids - - if prefill and prefill_logprobs: - # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) - torch.log_softmax(out, -1, out=out) - prefill_logprobs_tensor = out - prefill_logprobs = torch.gather( - prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) - ) - # HPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - - # Does a HPU <-> CPU sync internally - if prefill and 
finished_prefilling: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) - - # HPU <-> CPU sync - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids.tolist() - accepted_ids = accepted_ids.tolist() - - # Update values if we need to continue prefilling - # This represents the `else` case of the `Update values` if above - # but since this require the `next_token_ids` to be on CPU, it is better to do it here - if prefill and not finished_prefilling: - # Speculation must be ignored while we prefill even with chunking - # it simplifies everything - assert batch.speculative_ids is None - - all_postfix_ids = [] - for i, ( - request_prefilling, - next_token_id, - all_input_ids, - cache_length, - input_length, - next_chunk_length, - ) in enumerate( - zip( - batch.prefilling_mask, - next_token_ids, - batch.all_input_ids, - batch.cache_lengths, - batch.input_lengths, - next_chunk_lengths, - ) - ): - if request_prefilling: - next_cache_length = cache_length + input_length - # Get new prompt IDs to prefill - postfix_ids = all_input_ids[ - next_cache_length : next_cache_length + next_chunk_length - ] - else: - # This request is done prefilling, the new id is the one selected the sampling method - postfix_ids = [next_token_id] - - all_postfix_ids.append(postfix_ids) - - batch.input_ids = all_postfix_ids + # HPU->CPU sync + for prev_batch in prev_batches: + prev_batch["next_token_logprobs"] = prev_batch[ + "next_token_logprobs" + ].tolist() + prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist() + prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist() start_decode = time.time_ns() - + # Stage 3. 
Finish and return previous generations # Results generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - batch.stopping_criterias, - batch.all_input_ids, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - batch.top_n_tokens, - current_prefilling_mask, - batch.prefilling_mask, - accepted_ids, - batch_top_token_ids, - batch_top_token_logprobs, - ) - + stopped = len(requests_to_generate) > 0 # Reset max_input_length batch.max_input_length = 0 # For each member of the batch - index = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - top_n_tokens, - request_was_prefilling, - request_is_prefilling, - n_accepted_ids, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Compute logprobs first as, even though we might skip the token, - # it can still be required to compute the logprobs - # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need - # this state to be stable - if request.id % self.world_size == self.rank: - # Prefill - if request_was_prefilling and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - if not request_is_prefilling: - # The request is dones prefilling, meaning that we started generating new tokens - # The last logprob is a logprob for a generated token that was not part of the prompt - # We need to remove it - out_end_index -= 1 + indexs = [0] * len(prev_batches) + idx_accept_ids = [0] * len(prev_batches) + for i, req_data in enumerate(requests_to_generate): + idx = req_data["idx"] + request_id = req_data["request_id"] + prefix_offset = req_data["prefix_offset"] + read_offset = req_data["read_offset"] + stopping_criteria = req_data["stopping_criteria"] + all_input_ids = req_data["all_input_ids"] + do_sample = req_data["do_sample"] + seed = req_data["seed"] + top_n_tokens = req_data["top_n_tokens"] + n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]] + top_token_ids = req_data["top_token_ids"] + top_token_logprobs = req_data["top_token_logprobs"] + # Append next token to all tokens + next_token_texts = [] + left = 0 - request_prefill_logprobs = prefill_logprobs[ - out_start_index:out_end_index - ] - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[ - cache_length + 1 : cache_length + input_length + 1 - ] + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + current_stopped = False + index = indexs[idx] + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = prev_batches[idx]["next_token_ids"][j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - if past_prefill_logprob_tokens is None: - # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * ( - cache_length + 1 - ) + request_prefill_logprobs - prefill_token_ids = ( - all_input_ids[: cache_length + 1] + prefill_token_ids - ) + stop, reason = 
stopping_criteria( + next_token_id, + next_token_text, + ) - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - - prefill_logprob_tokens = Tokens( - prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - is_special=[], - ) - if past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = ( - past_prefill_logprob_tokens + prefill_logprob_tokens - ) - - batch.prefill_logprob_tokens[i] = prefill_logprob_tokens + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break else: - batch.prefill_logprob_tokens[i] = None + current_stopped = False + stopped = stopped and current_stopped - # If it is, the tokens we decoded should be ignored - if request_is_prefilling: - # Make sure that we do not stop as even though this request did not create a token, it is still - # processing - stopped = False - new_input_length = next_chunk_lengths[i] - new_cache_length = cache_length + input_length - else: - new_input_length = 1 - new_cache_length = cache_length + input_length + n_accepted_ids - 1 - # Append next token to all tokens - next_token_texts = [] - left = 0 + _next_token_ids = prev_batches[idx]["next_token_ids"][ + index : index + n_accepted_ids - left + ] + _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][ + index : index + n_accepted_ids - left + ] - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( + # Shard generations + # All generations will be appended in the rust sharded client + if request_id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( all_input_ids, - prefix_offset, - read_offset, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) + else: + generated_text = None - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - - # Shard generations - # All generations will be appended in the rust sharded client - if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - 
seed if do_sample else None, + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, ) - else: - generated_text = None + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None + generation = Generation( + request_id, + None, + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) - generation = Generation( - request.id, - batch.prefill_logprob_tokens[i], - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) + generations.append(generation) # accept each new token for this specific request since we may # have more than one new token per request with speculative decoding @@ -2151,12 +2180,8 @@ class FlashCausalLM(Model): ) # Update values - index += n_accepted_ids - batch.cache_lengths[i] = new_cache_length - batch.max_input_length = max(batch.max_input_length, new_input_length) - batch.input_lengths[i] = new_input_length - current_length = new_cache_length + new_input_length - batch.max_current_length = max(batch.max_current_length, current_length) + indexs[idx] += n_accepted_ids + idx_accept_ids[idx] += 1 batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset @@ -2168,14 +2193,6 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - if prefill and finished_prefilling: - # We do not need prefill tensors anymore - batch.cu_seqlen_prefill = None - batch.prefill_cache_indices = None - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None - forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 208ab358..c885816b 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, + prepare_for_decode, ) -from text_generation_server.models.globals import PREFIX_CACHING +from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata import 
habana_frameworks.torch as htorch +from text_generation_server.utils.import_utils import ( + synchronize, +) +import torch.nn.functional as F tracer = trace.get_tracer(__name__) @@ -375,6 +380,91 @@ class FlashVlmCausalLM(FlashCausalLM): def max_past(self) -> Optional[int]: return getattr(self.model.text_model, "max_past", None) + def warmup_decode( + self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch + ): + input_ids = torch.zeros( + batch_size, dtype=batch.input_ids.dtype, device=self.device + ) + position_ids = torch.arange( + batch_size, dtype=batch.position_ids.dtype, device=self.device + ) + if batch.position_ids is not None and batch.position_ids.dim() == 2: + # qwen2_vl and qwen2_5_vl case + position_ids = position_ids.unsqueeze(-1).repeat( + (1, batch.position_ids.shape[-1]) + ) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] + slots = [] + start_idx = 0 + + # fetch the last blocked to warmup block num + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + hpu_attention_meta = prepare_for_decode( + self.dtype, + self.use_contiguous_pa, + self.device, + slots, + block_tables, + batch_size, + bucketing_ctx=None, + ) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
+ self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + slots=slots_tensor, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=hpu_attention_meta, + lm_head_indices=None, + pixel_values=None, + pixel_attention_mask=None, + image_sizes=None, + image_grid_thw=None, + ) + + def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): + warmup_times = 3 + # only warmup decode; for prefill, image pixel size may change, making the warmup useless + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + if batch_size > block_num: + continue + log_master( + logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + ) + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + def forward( self, batch: FlashVlmCausalLMBatch, @@ -450,17 +540,75 @@ class FlashVlmCausalLM(FlashCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = batch.prefilling - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + if position_ids.dim() == 2: + # qwen2_vl and qwen2_5_vl case + position_ids = F.pad( + position_ids, + (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]), + value=1, + ) + else: + position_ids = F.pad( + position_ids, + (0, (padded_bs - orig_bs) * input_seq.shape[-1]), + value=1, + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -476,8 +624,6 @@ class FlashVlmCausalLM(FlashCausalLM): image_grid_thw=batch.image_grid_thw, **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None if batch.pixel_attention_mask is not None: @@ -486,4 +632,6 @@ 
class FlashVlmCausalLM(FlashCausalLM): batch.image_sizes = None if batch.image_grid_thw is not None: batch.image_grid_thw = None - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index e034ed49..c1ea36f2 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -11,7 +11,9 @@ from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) - +from text_generation_server.models.flash_causal_lm import ( + prepare_for_decode, +) from text_generation_server.models.flash_vlm_causal_lm import ( FlashVlmCausalLMBatch, FlashVlmCausalLM, @@ -19,6 +21,13 @@ from text_generation_server.models.flash_vlm_causal_lm import ( from text_generation_server.pb import generate_pb2 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata import habana_frameworks.torch as htorch +from loguru import logger +from text_generation_server.models.globals import BLOCK_SIZE +from text_generation_server.utils.import_utils import ( + synchronize, +) +import torch.nn.functional as F +from text_generation_server.utils.log import log_master tracer = trace.get_tracer(__name__) @@ -196,7 +205,178 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): return batch +def generate_cross_attention_states( + cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling +): + if cross_attention_states is None: + return None, None, None + device = cross_attention_states.device + indices_list = [] + if prefilling: + for i in image_indices: + indices_list.append( + torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device) + ) + indices = torch.cat(indices_list, dim=0) + else: + indices = image_indices[:] + return indices, seqlen.input_lengths.index_select(0, image_indices) + + class FlashMllamaCausalLM(FlashVlmCausalLM): + def warmup_decode( + self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch + ): + input_ids = torch.zeros( + batch_size, dtype=batch.input_ids.dtype, device=self.device + ) + position_ids = torch.arange( + batch_size, dtype=batch.position_ids.dtype, device=self.device + ) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] + slots = [] + start_idx = 0 + + # fetch the last blocked to warmup block num + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + hpu_attention_meta = prepare_for_decode( + self.dtype, + self.use_contiguous_pa, + self.device, + slots, + block_tables, + batch_size, + bucketing_ctx=None, + ) + # We pass a `cu_seqlen_prefill` 
in order not to have to deal with paged attention cache allocation/deallocation. + image_indices = torch.tensor(batch.image_indices, device=self.device) + image_indices = image_indices.repeat(batch_size) + cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1) + indices, cross_attention_len = generate_cross_attention_states( + cross_attention_states, image_indices, seqlen, 1, False + ) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + slots=slots_tensor, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=hpu_attention_meta, + lm_head_indices=None, + adapter_data=None, + cross_attention_states=cross_attention_states, + indices=indices, + cross_attention_len=cross_attention_len, + ) + + def warmup_prefill( + self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch + ): + input_ids = torch.zeros( + prompt_len, dtype=batch.input_ids.dtype, device=self.device + ).repeat(batch_size) + position_ids = torch.arange( + prompt_len, dtype=batch.position_ids.dtype, device=self.device + ).repeat(batch_size) + max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).reshape(batch_size, -1) + slot_acc = [] + for i in range(batch_size): + slots = [] + for b in block_tables[i]: + slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) + slot_acc.extend(slots[:prompt_len]) + slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) + + input_lengths = ( + torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + lm_head_indices = input_lengths - 1 + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
+ image_indices = torch.tensor(batch.image_indices, device=self.device) + image_indices = image_indices.repeat(batch_size) + cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1) + indices, cross_attention_len = generate_cross_attention_states( + cross_attention_states, image_indices, seqlen, prompt_len, True + ) + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=None, + lm_head_indices=lm_head_indices, + adapter_data=None, + cross_attention_states=cross_attention_states, + indices=indices, + cross_attention_len=cross_attention_len, + ) + + def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): + warmup_times = 3 + self.bucketing_ctx.generate_prompt_buckets() + for i, (batch_size, seq_len) in enumerate( + reversed(self.bucketing_ctx.prompt_buckets) + ): + log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + if batch_size > block_num: + continue + log_master( + logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + ) + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + def forward( self, batch: FlashMllamaCausalLMBatch, @@ -263,12 +443,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - if batch.pixel_values is not None: cross_attention_states = self.model.vision_forward( pixel_values=batch.pixel_values, @@ -281,11 +455,82 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = batch.prefilling if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] + padded_input_len = input_ids.view(orig_bs, -1).shape[-1] + image_indices = torch.tensor(batch.image_indices, device=self.device) + if padded_bs != input_lengths.shape[0]: + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0 + ) + position_ids = F.pad( + position_ids, (0, (padded_bs - 
orig_bs) * padded_input_len), value=1 + ) + slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + if cross_attention_states is not None: + cross_attention_states = F.pad( + cross_attention_states, + (0, 0, 0, 0, 0, (padded_bs - orig_bs)), + value=0, + ) + if len(image_indices) != 0: + pad_indices = torch.arange(orig_bs, padded_bs, device=self.device) + image_indices = torch.cat((image_indices, pad_indices), dim=0) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + indices, cross_attention_len = generate_cross_attention_states( + cross_attention_states, + image_indices, + seqlen, + padded_input_len, + batch.prefilling, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -295,14 +540,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, lm_head_indices=lm_head_indices, - cross_attention_states=cross_attention_states, # TODO list adapter_data=None, - image_indices=batch.image_indices[:], + cross_attention_states=cross_attention_states, + indices=indices, + cross_attention_len=cross_attention_len, **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 6da2b51d..1628a00b 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator { (required_blocks, repeats) }; - let tokens = tokens as usize; + let mut tokens = tokens as usize; if required_blocks > self.free_blocks.len() as u32 { None } else { @@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator { .split_off(self.free_blocks.len() - required_blocks as usize); if self.is_hpu_device { blocks.sort(); + // need 1 slot for ping-pong optimization + tokens += 1; } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);