diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
index d6e512c0..648b010a 100644
--- a/server/text_generation_server/layers/attention/common.py
+++ b/server/text_generation_server/layers/attention/common.py
@@ -9,7 +9,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 
     @dataclass
     class Seqlen:
-        input_lengths: torch.Tensor
+        postfix_lengths: torch.Tensor
         prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
@@ -18,16 +18,16 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 
         def __init__(
             self,
-            input_lengths,
+            postfix_lengths,
             prefix_lengths,
             cu_seqlen_q=None,
             max_q=None,
             max_k=None,
         ):
-            self.input_lengths = input_lengths
+            self.postfix_lengths = postfix_lengths
             self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
+            device = self.postfix_lengths.device
+            shape = self.postfix_lengths.shape
             if cu_seqlen_q is None:
                 cu_seqlen_q = torch.arange(
                     shape[0] + 1,
@@ -43,7 +43,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
+            total = self.postfix_lengths + self.prefix_lengths
             torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
@@ -59,7 +59,7 @@ else:
 
    @dataclass
    class Seqlen:
-        input_lengths: torch.Tensor
+        postfix_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: torch.Tensor
        max_q: int
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 33fe30a8..bb35886c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -143,9 +143,6 @@ class FlashCausalLMBatch(Batch):
     block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: torch.Tensor
-    # size [b], containing the number of blocks that can be retrieved from the cache
-    prefix_lens: List[int]
-    prefix_lens_tensor: torch.Tensor
 
     max_seqlen: int
 
@@ -162,8 +159,14 @@ class FlashCausalLMBatch(Batch):
     all_input_ids_tensor: torch.Tensor
 
     # Lengths of all generations present in the batch
-    input_lengths: List[int]
-    input_lengths_tensor: torch.Tensor
+    postfix_lengths: List[int]
+    postfix_lengths_tensor: torch.Tensor
+    # size [b], containing the number of blocks that can be retrieved from the cache
+    prefix_lengths: List[int]
+    prefix_lengths_tensor: torch.Tensor
+    prompt_lengths: List[int]
+    prompt_lengths_tensor: torch.Tensor
+
     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
 
@@ -225,10 +228,13 @@ class FlashCausalLMBatch(Batch):
         slot_indices = []
         prefill_cache_indices = []
 
-        input_lengths = []
+        prefix_lengths = []
+        postfix_lengths = []
+        prompt_lengths = []
         prefix_offsets = []
         read_offsets = []
         all_input_ids = []
+        all_postfix_ids = []
         prefix_ids = []
 
         requests_idx_mapping = {}
@@ -257,7 +263,6 @@ class FlashCausalLMBatch(Batch):
 
         block_tables = []
         slots = []
-        prefix_lens = []
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -266,37 +271,39 @@ class FlashCausalLMBatch(Batch):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            orig_input_length = len(tokenized_input)
+            prompt_length = len(tokenized_input)
+            prompt_lengths.append(prompt_length)
 
-            prefix_len = r.prefix_len
+            prefix_length = r.prefix_len
             assert (
-                prefix_len <= orig_input_length
-            ), f"Prefix {prefix_len} vs input {orig_input_length}"
-            if prefix_len == orig_input_length:
-                assert prefix_len > 0
-                prefix_len -= 1
+                prefix_length <= prompt_length
+            ), f"Prefix {prefix_length} vs input {prompt_length}"
+            if prefix_length == prompt_length:
+                assert prefix_length > 0
+                prefix_length -= 1
 
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
-            prefix_ids.append(tokenized_input[:prefix_len])
-            tokenized_input = tokenized_input[prefix_len:]
+            prefix_ids.append(tokenized_input[:prefix_length])
+            postfix_ids = tokenized_input[prefix_length:]
 
-            input_length = len(tokenized_input)
-            input_lengths.append(input_length)
+            postfix_length = len(postfix_ids)
+            postfix_lengths.append(postfix_length)
 
-            prefix_offsets.append(input_length - 5)
-            read_offsets.append(input_length)
+            prefix_offsets.append(postfix_length - 5)
+            read_offsets.append(postfix_length)
 
+            all_postfix_ids.append(postfix_ids)
             all_input_ids.append(tokenized_input)
 
             # Position ids
             request_position_ids = torch.arange(
-                prefix_len, orig_input_length, dtype=torch.int32
+                prefix_length, prompt_length, dtype=torch.int32
             )
             position_ids.append(request_position_ids)
 
             # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
+            cu_seqlen_prefill.append(cumulative_length + postfix_length)
 
             next_token_chooser_parameters.append(r.parameters)
 
@@ -309,7 +316,7 @@ class FlashCausalLMBatch(Batch):
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
-            adapter_indices_list.append(torch.full((input_length,), adapter_index))
+            adapter_indices_list.append(torch.full((postfix_length,), adapter_index))
             adapter_set.add(adapter_index)
 
             # Paged attention
@@ -318,11 +325,11 @@ class FlashCausalLMBatch(Batch):
             speculative_length = 0 if speculative_length is None else speculative_length
 
             # Tokens that need to be mapped to blocks.
-            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
 
             # Tokens that need to be mapped to slots. We don't need slots for the
             # cached prefix (if present).
-            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+            slot_tokens = postfix_length + max_new_tokens - 1 + speculative_length
 
             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
@@ -338,19 +345,19 @@ class FlashCausalLMBatch(Batch):
             else:
                 request_blocks = r.blocks
                 request_slots = r.slots[
-                    prefix_len: #: orig_input_length + max_new_tokens + speculative_length
+                    prefix_length: #: orig_input_length + max_new_tokens + speculative_length
                 ]
 
             block_tables.append(request_blocks)
             slots.extend(request_slots)
-            prefix_lens.append(prefix_len)
+            prefix_lengths.append(prefix_length)
             num_blocks += len(request_blocks)
             start_slots.append(cumulative_slot_tokens)
 
             request_slot_indices = torch.arange(
                 cumulative_slot_tokens,
-                cumulative_slot_tokens + input_length,
+                cumulative_slot_tokens + postfix_length,
                 dtype=torch.int64,
             )
             slot_indices.append(request_slot_indices)
 
@@ -358,8 +365,8 @@ class FlashCausalLMBatch(Batch):
             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
+                    cumulative_length + max(0, postfix_length - sliding_window),
+                    cumulative_length + postfix_length,
                     dtype=torch.int64,
                 )
                 prefill_cache_indices.append(request_prefill_cache_indices)
@@ -370,14 +377,16 @@ class FlashCausalLMBatch(Batch):
             if r.prefill_logprobs:
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
+                    prefill_out_cumulative_length + postfix_length - 1
                 )
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
+                prefill_cu_outlens.append(
+                    prefill_out_cumulative_length + postfix_length
+                )
+                prefill_out_cumulative_length += postfix_length
             else:
                 prefill_head_indices.append(
                     torch.tensor(
-                        [cumulative_length + input_length - 1], dtype=torch.int32
+                        [cumulative_length + postfix_length - 1], dtype=torch.int32
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -385,12 +394,13 @@ class FlashCausalLMBatch(Batch):
                 prefill_out_cumulative_length += 1
 
             # Update
-            cumulative_length += input_length
+            cumulative_length += postfix_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, input_length)
+            max_seqlen = max(max_seqlen, postfix_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
-                max_length, input_length + max_new_tokens + speculative_length
+                max_length,
+                prefix_length + postfix_length + max_new_tokens + speculative_length,
             )
 
         adapter_indices = torch.cat(adapter_indices_list).to(
@@ -415,13 +425,13 @@ class FlashCausalLMBatch(Batch):
         )
 
         if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
             if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
-            input_ids = all_input_ids[0]
+            input_ids = all_postfix_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
             if sliding_window is not None:
@@ -436,8 +446,11 @@ class FlashCausalLMBatch(Batch):
             prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
+        postfix_lengths_tensor = torch.tensor(
+            postfix_lengths, dtype=torch.int32, device=device
+        )
+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
         )
 
         adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
@@ -470,7 +483,9 @@ class FlashCausalLMBatch(Batch):
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
-        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        prefix_lengths_tensor = torch.tensor(
+            prefix_lengths, dtype=torch.int32, device=device
+        )
 
         return cls(
             batch_id=pb.id,
@@ -485,14 +500,16 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -556,8 +573,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
 
-        input_lengths = []
-        prefix_lens = []
+        postfix_lengths = []
+        prefix_lengths = []
         prefix_offsets = []
         read_offsets = []
 
@@ -578,15 +595,15 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
 
             # Get length
-            request_input_length = self.input_lengths[idx]
-            prefix_len = self.prefix_lens[idx]
+            request_input_length = self.postfix_lengths[idx]
+            prefix_length = self.prefix_lengths[idx]
             max_seqlen = max(max_seqlen, request_input_length)
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            input_lengths.append(request_input_length)
-            prefix_lens.append(prefix_len)
+            postfix_lengths.append(request_input_length)
+            prefix_lengths.append(prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -629,9 +646,9 @@ class FlashCausalLMBatch(Batch):
         adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        input_lengths_tensor = self.input_lengths_tensor[indices]
+        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
-        prefix_lens_tensor = self.prefix_lens_tensor[indices]
+        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
@@ -666,10 +683,10 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -720,7 +737,7 @@ class FlashCausalLMBatch(Batch):
                 + speculative_length
                 - stopping_criteria.current_tokens
                 for input_length, stopping_criteria in zip(
-                    b.input_lengths, b.stopping_criterias
+                    b.postfix_lengths, b.stopping_criterias
                 )
             ),
         )
@@ -729,13 +746,15 @@ class FlashCausalLMBatch(Batch):
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
         slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
-        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
+        postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
             total_batch_size
         )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
+        prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+            total_batch_size
+        )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
@@ -753,11 +772,11 @@ class FlashCausalLMBatch(Batch):
 
         start_slots = []
         block_tables = []
-        prefix_lens = []
+        prefix_lengths = []
         all_input_ids = []
         prefix_ids = []
 
-        input_lengths = []
+        postfix_lengths = []
         prefix_offsets = []
         read_offsets = []
 
@@ -790,7 +809,7 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
@@ -817,16 +836,16 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
 
-            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+            prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
 
             start_slots.append(batch.start_slots + cumulative_slots)
 
             block_tables.extend(batch.block_tables)
-            prefix_lens.extend(batch.prefix_lens)
+            prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
 
-            input_lengths.extend(batch.input_lengths)
+            postfix_lengths.extend(batch.postfix_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
 
@@ -872,15 +891,15 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -1100,9 +1119,9 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = [max_s] * bs
+        postfix_lengths = [max_s] * bs
         prefix_lengths = [0] * bs
-        input_lengths_tensor = (
+        postfix_lengths_tensor = (
             torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         )
         prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
@@ -1114,8 +1133,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=input_lengths,
-                prefix_lens=prefix_lengths,
+                postfix_lengths=postfix_lengths,
+                prefix_lengths=prefix_lengths,
             )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
@@ -1143,7 +1162,7 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "input_lengths": input_lengths_tensor,
+            "postfix_lengths": postfix_lengths_tensor,
             "prefix_lengths": prefix_lengths_tensor,
             "state": state,
             "graph": graph,
@@ -1154,12 +1173,12 @@ class FlashCausalLM(Model):
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths_tensor=postfix_lengths_tensor,
             state=state,
-            prefix_lens_tensor=prefix_lengths_tensor,
+            prefix_lengths_tensor=prefix_lengths_tensor,
         ):
             seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
+                postfix_lengths=postfix_lengths_tensor,
                 prefix_lengths=prefix_lengths_tensor,
                 cu_seqlen_q=None,
                 max_q=1,
@@ -1183,7 +1202,7 @@ class FlashCausalLM(Model):
 
             with torch.cuda.graph(graph, pool=MEM_POOL):
                 seqlen = Seqlen(
-                    input_lengths=input_lengths_tensor,
+                    postfix_lengths=postfix_lengths_tensor,
                     prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=None,
                     max_q=1,
@@ -1340,15 +1359,15 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
 
         # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        prefix_lengths_tensor = torch.zeros(
+            seqlen, dtype=torch.int32, device=self.device
+        )
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            prefix_lengths=prefix_lens_tensor,
+            postfix_lengths=postfix_lengths,
+            prefix_lengths=prefix_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
             max_q=1,
             max_k=seqlen,
@@ -1379,7 +1400,7 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -1396,11 +1417,11 @@ class FlashCausalLM(Model):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)
 
             # Add Copy the block tables for all members
@@ -1421,8 +1442,8 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -1444,19 +1465,19 @@ class FlashCausalLM(Model):
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                postfix_lengths_tensor=postfix_lengths,
+                prefix_lengths_tensor=prefix_lengths_tensor,
            ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
                seqlen = Seqlen(
-                    input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    postfix_lengths=postfix_lengths,
+                    prefix_lengths=prefix_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=max_s,
                    max_k=max_k,
@@ -1485,8 +1506,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                postfix_lengths=batch.postfix_lengths,
+                prefix_lengths=batch.prefix_lengths,
             )
             # assert block_tables.shape[0] >= slots.shape[0]
         cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@@ -1499,16 +1520,18 @@ class FlashCausalLM(Model):
         # so it doesn't matter if we override it with bogus values.
         cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["postfix_lengths"].zero_()
+        cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
         cuda_graph["prefix_lengths"].zero_()
-        cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
+        cuda_graph["prefix_lengths"][
+            : prefix_lengths_tensor.shape[0]
+        ] = prefix_lengths_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            input_lengths_tensor=cuda_graph["input_lengths"],
-            prefix_lens_tensor=cuda_graph["prefix_lengths"],
+            postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+            prefix_lengths_tensor=cuda_graph["prefix_lengths"],
             state=cuda_graph["state"],
         ):
             # Replay the graph
@@ -1586,7 +1609,7 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen],
+            batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)],
             next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -1619,7 +1642,12 @@ class FlashCausalLM(Model):
             stopped = True
 
         # Zipped iterator
-        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
+        iterator = zip(
+            batch.prefix_lengths,
+            batch.postfix_lengths,
+            batch.all_input_ids,
+            accepted_ids,
+        )
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a GPU <-> CPU sync
@@ -1627,10 +1655,15 @@ class FlashCausalLM(Model):
 
         # For each member of the batch
         index = 0
-        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
+        for i, (
+            prefix_length,
+            postfix_length,
+            all_input_ids,
+            n_accepted_ids,
+        ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
-            end_index = cumulative_length + input_length
+            end_index = cumulative_length + postfix_length
 
             if prefill:
                 # Indexing metadata
@@ -1662,16 +1695,18 @@ class FlashCausalLM(Model):
                 ]
 
             for j in range(n_accepted_ids):
-                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = (
+                    next_input_ids[index]
+                )
                 index += 1
 
-            cumulative_length += input_length
+            cumulative_length += postfix_length
 
         # Update values
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.input_lengths_tensor += accepted_ids
+        batch.postfix_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
@@ -1702,7 +1737,7 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.input_lengths,
+            batch.postfix_lengths,
             batch.prefix_offsets,
             batch.read_offsets,
             batch.stopping_criterias,
@@ -1867,9 +1902,9 @@ class FlashCausalLM(Model):
             )
 
             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids
-            if batch.input_lengths[i] > batch.max_seqlen:
-                batch.max_seqlen = batch.input_lengths[i]
+            batch.postfix_lengths[i] = input_length + n_accepted_ids
+            if batch.postfix_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.postfix_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1893,8 +1928,8 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths_tensor: torch.Tensor,
-        prefix_lens_tensor: torch.Tensor,
+        postfix_lengths_tensor: torch.Tensor,
+        prefix_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
@@ -1905,7 +1940,7 @@ class FlashCausalLM(Model):
             use_prefill_with_paged_kv_state,
         )
 
-        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
+        # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths)
 
         if cu_seqlen_prefill is not None:
             return use_prefill_with_paged_kv_state(
@@ -1914,12 +1949,12 @@ class FlashCausalLM(Model):
                 ),
                 # block_tables=block_tables_to_ragged(
                 #     block_tables=block_tables,
-                #     input_lengths=input_lengths,
-                #     prefix_lens=prefix_lens,
+                #     postfix_lengths=postfix_lengths,
+                #     prefix_lengths=prefix_lengths,
                 # ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
@@ -1928,10 +1963,10 @@ class FlashCausalLM(Model):
                 window_left=self.sliding_window,
             )
         else:
-            assert input_lengths_tensor is not None
+            assert postfix_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
@@ -1943,19 +1978,21 @@ def block_tables_to_ragged(
-    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+    *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int]
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) + assert len(postfix_lengths) == len(prefix_lengths) - total_len = sum(input_lengths) + sum(prefix_lens) + total_len = sum(postfix_lengths) + sum(prefix_lengths) block_tables_ragged = torch.empty( total_len, dtype=torch.int32, device=block_tables.device ) offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length + for i, (input_length, prefix_length) in enumerate( + zip(postfix_lengths, prefix_lengths) + ): + seq_len = prefix_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len