diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index 648b010a..8f9d93a1 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -9,8 +9,8 @@ if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass class Seqlen: - postfix_lengths: torch.Tensor - prefix_lengths: torch.Tensor + input_lengths: torch.Tensor + cache_lengths: torch.Tensor cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor] max_q: int @@ -18,16 +18,16 @@ if ATTENTION in {"flashinfer", "flashdecoding"}: def __init__( self, - postfix_lengths, - prefix_lengths, + input_lengths, + cache_lengths, cu_seqlen_q=None, max_q=None, max_k=None, ): - self.postfix_lengths = postfix_lengths - self.prefix_lengths = prefix_lengths - device = self.postfix_lengths.device - shape = self.postfix_lengths.shape + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape if cu_seqlen_q is None: cu_seqlen_q = torch.arange( shape[0] + 1, @@ -43,7 +43,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}: # cuda graphs don't like this and this is necessary to clamp within mistral # Although FA2 might not want the clamping # cu_seqlen_k[0] = 0 - total = self.postfix_lengths + self.prefix_lengths + total = self.input_lengths + self.cache_lengths torch.cumsum(total, -1, out=cu_seqlen_k[1:]) self.cu_seqlen_q = cu_seqlen_q @@ -59,8 +59,8 @@ else: @dataclass class Seqlen: - postfix_lengths: torch.Tensor - prefix_lengths: torch.Tensor + input_lengths: torch.Tensor + cache_lengths: torch.Tensor cu_seqlen_q: torch.Tensor max_q: int max_k: int diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6e8a0097..9a34dfc5 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -150,7 +150,7 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward before staying set in decode slots: Optional[torch.Tensor] - max_postfix_length: int + max_input_length: int max_current_length: int # Whether this batch contains at least one request that is prefilling @@ -181,13 +181,13 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor: torch.Tensor # Lengths of all generations present in the batch - postfix_lengths: List[int] + input_lengths: List[int] # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lengths: List[int] + cache_lengths: List[int] prompt_lengths: List[int] # Will be set by `generate_token` and reset after each prefill forward before staying set in decode - postfix_lengths_tensor: Optional[torch.Tensor] - prefix_lengths_tensor: Optional[torch.Tensor] + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] prompt_lengths_tensor: torch.Tensor prefix_offsets: List[Optional[int]] @@ -252,8 +252,8 @@ class FlashCausalLMBatch(Batch): ) -> "FlashCausalLMBatch": speculate = get_speculate() - prefix_lengths = [] - postfix_lengths = [] + cache_lengths = [] + input_lengths = [] prompt_lengths = [] prefix_offsets = [] read_offsets = [] @@ -267,7 +267,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens = [] num_blocks = 0 - max_postfix_length = 0 + max_input_length = 0 max_current_length = 0 max_length = 0 max_blocks = 0 @@ -284,28 +284,26 @@ 
class FlashCausalLMBatch(Batch): prompt_length = len(tokenized_input) prompt_lengths.append(prompt_length) - prefix_length = r.prefix_len - postfix_length = r.postfix_len + cache_length = r.prefix_len + input_length = r.postfix_len assert ( - prefix_length <= prompt_length - ), f"Prefix {prefix_length} vs input {prompt_length}" - if prefix_length == prompt_length: + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input {prompt_length}" + if cache_length == prompt_length: assert False, "unreachable" - if prefix_length + postfix_length < prompt_length: + if cache_length + input_length < prompt_length: # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 assert get_support_chunking() - assert postfix_length > 0 + assert input_length > 0 - prefix_ids.append(tokenized_input[:prefix_length]) - postfix_ids = tokenized_input[ - prefix_length : prefix_length + postfix_length - ] + prefix_ids.append(tokenized_input[:cache_length]) + postfix_ids = tokenized_input[cache_length : cache_length + input_length] assert ( - len(postfix_ids) == postfix_length + len(postfix_ids) == input_length ), "Rust and Python tokenizers are not aligned" - postfix_lengths.append(postfix_length) + input_lengths.append(input_length) prefix_offsets.append(prompt_length - 5) read_offsets.append(prompt_length) @@ -341,13 +339,13 @@ class FlashCausalLMBatch(Batch): block_tables.append(request_blocks) - prefix_lengths.append(prefix_length) + cache_lengths.append(cache_length) num_blocks += len(request_blocks) # Update max_blocks = max(max_blocks, len(request_blocks)) - max_postfix_length = max(max_postfix_length, postfix_length) - max_current_length = max(max_current_length, prefix_length + postfix_length) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) max_length = max( max_length, prompt_length + max_new_tokens + speculative_length, @@ -390,13 +388,13 @@ class FlashCausalLMBatch(Batch): input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lengths=prefix_lengths, - max_postfix_length=max_postfix_length, + cache_lengths=cache_lengths, + max_input_length=max_input_length, max_current_length=max_current_length, prefilling=True, prefilling_mask=[True] * len(pb.requests), prefill_logprob_tokens=[None] * len(pb.requests), - postfix_lengths=postfix_lengths, + input_lengths=input_lengths, prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, @@ -420,8 +418,8 @@ class FlashCausalLMBatch(Batch): prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - prefix_lengths_tensor=None, - postfix_lengths_tensor=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, adapter_meta=None, ) @@ -460,7 +458,7 @@ class FlashCausalLMBatch(Batch): # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_postfix_length = 0 + max_input_length = 0 max_current_length = 0 requests = [] @@ -470,8 +468,8 @@ class FlashCausalLMBatch(Batch): input_ids = [] prompt_lengths = [] - postfix_lengths = [] - prefix_lengths = [] + input_lengths = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] @@ -499,19 +497,19 @@ class FlashCausalLMBatch(Batch): prefilling_mask.append(request_prefilling) # Get length - request_postfix_length = self.postfix_lengths[idx] - request_prefix_length = self.prefix_lengths[idx] - 
max_postfix_length = max(max_postfix_length, request_postfix_length) + request_input_length = self.input_lengths[idx] + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) max_current_length = max( - max_current_length, request_prefix_length + request_postfix_length + max_current_length, request_cache_length + request_input_length ) all_input_ids.append(self.all_input_ids[idx]) prefix_ids.append(self.prefix_ids[idx]) prompt_lengths.append(self.prompt_lengths[idx]) - postfix_lengths.append(request_postfix_length) - prefix_lengths.append(request_prefix_length) + input_lengths.append(request_input_length) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -544,12 +542,12 @@ class FlashCausalLMBatch(Batch): # Set slice slot_filtering_indices[ self.slot_indices[idx] : self.slot_indices[idx] - + request_postfix_length + + request_input_length + remaining_tokens - 1 ] = True - cumulative_max_length += request_postfix_length + remaining_tokens - 1 + cumulative_max_length += request_input_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) @@ -567,17 +565,17 @@ class FlashCausalLMBatch(Batch): position_ids = None slot_indices = None slots = None - prefix_lengths_tensor = None - postfix_lengths_tensor = None + cache_lengths_tensor = None + input_lengths_tensor = None adapter_meta = None else: # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] - postfix_lengths_tensor = self.postfix_lengths_tensor[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] - prefix_lengths_tensor = self.prefix_lengths_tensor[indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) @@ -605,7 +603,7 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_postfix_length=max_postfix_length, + max_input_length=max_input_length, max_current_length=max_current_length, prefilling=self.prefilling, prefilling_mask=prefilling_mask, @@ -615,10 +613,10 @@ class FlashCausalLMBatch(Batch): prefill_logprob_tokens=prefill_logprob_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, - postfix_lengths=postfix_lengths, - postfix_lengths_tensor=postfix_lengths_tensor, - prefix_lengths=prefix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -647,7 +645,7 @@ class FlashCausalLMBatch(Batch): total_slots = 0 max_blocks = 0 max_length = 0 - max_postfix_length = 0 + max_input_length = 0 max_current_length = 0 for b in batches: total_batch_size += len(b) @@ -659,7 +657,7 @@ class FlashCausalLMBatch(Batch): speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_postfix_length = max(max_postfix_length, b.max_postfix_length) + max_input_length = max(max_input_length, b.max_input_length) max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, @@ -680,8 +678,8 @@ class 
FlashCausalLMBatch(Batch): position_ids = None slots = None slot_indices = None - prefix_lengths_tensor = None - postfix_lengths_tensor = None + cache_lengths_tensor = None + input_lengths_tensor = None adapter_meta = None adapter_segment_builder = None else: @@ -689,10 +687,10 @@ class FlashCausalLMBatch(Batch): position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size ) - prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty( + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( total_batch_size ) total_indices_size = sum( @@ -718,12 +716,12 @@ class FlashCausalLMBatch(Batch): ) block_tables = [] - prefix_lengths = [] + cache_lengths = [] all_input_ids = [] prefix_ids = [] prompt_lengths = [] - postfix_lengths = [] + input_lengths = [] prefix_offsets = [] read_offsets = [] @@ -773,9 +771,7 @@ class FlashCausalLMBatch(Batch): slot_indices[start_index:end_index] = ( batch.slot_indices + cumulative_slots ) - postfix_lengths_tensor[start_index:end_index] = ( - batch.postfix_lengths_tensor - ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor slots[slots_start_index:slots_end_index] = batch.slots # Copy over adapter indices @@ -793,9 +789,7 @@ class FlashCausalLMBatch(Batch): batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices, ) - prefix_lengths_tensor[start_index:end_index] = ( - batch.prefix_lengths_tensor - ) + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor # Update cumulative_slots += len(batch.slots) @@ -806,12 +800,12 @@ class FlashCausalLMBatch(Batch): prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lengths.extend(batch.prefix_lengths) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) prefix_ids.extend(batch.prefix_ids) prompt_lengths.extend(batch.prompt_lengths) - postfix_lengths.extend(batch.postfix_lengths) + input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) @@ -860,10 +854,10 @@ class FlashCausalLMBatch(Batch): slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lengths=prefix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_postfix_length=max_postfix_length, + max_input_length=max_input_length, max_current_length=max_current_length, prefilling=prefilling, prefilling_mask=prefilling_mask, @@ -873,8 +867,8 @@ class FlashCausalLMBatch(Batch): prefill_logprob_tokens=prefill_logprob_tokens, prompt_lengths=prompt_lengths, prompt_lengths_tensor=prompt_lengths_tensor, - postfix_lengths=postfix_lengths, - postfix_lengths_tensor=postfix_lengths_tensor, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -918,30 +912,30 @@ class FlashCausalLMBatch(Batch): for i, ( r, - prefix_length, - postfix_length, + cache_length, + input_length, prompt_length, request_prefilling, blocks, ) in enumerate( zip( self.requests, - self.prefix_lengths, - self.postfix_lengths, + self.cache_lengths, + self.input_lengths, 
self.prompt_lengths, self.prefilling_mask, self.block_tables, ) ): - next_chunk_length = postfix_length + next_chunk_length = input_length # Position ids request_position_ids = torch.arange( - prefix_length, prefix_length + postfix_length, dtype=torch.int32 + cache_length, cache_length + input_length, dtype=torch.int32 ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + postfix_length) + cu_seqlen_prefill.append(cumulative_length + input_length) if not r.slots: request_slots = [ @@ -952,18 +946,18 @@ class FlashCausalLMBatch(Batch): else: request_slots = r.slots - request_slots = request_slots[prefix_length:] + request_slots = request_slots[cache_length:] request_slot_indices = torch.arange( cumulative_slot_tokens, - cumulative_slot_tokens + postfix_length, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) # Create tensor to slice into the kv tensor in prefill if sliding_window is not None: request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, postfix_length - sliding_window), - cumulative_length + postfix_length, + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, dtype=torch.int64, ) @@ -976,16 +970,14 @@ class FlashCausalLMBatch(Batch): if prefill_logprobs: prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( - prefill_out_cumulative_length + postfix_length - 1 + prefill_out_cumulative_length + input_length - 1 ) - prefill_cu_outlens.append( - prefill_out_cumulative_length + postfix_length - ) - prefill_out_cumulative_length += postfix_length + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length else: prefill_head_indices.append( torch.tensor( - [cumulative_length + postfix_length - 1], + [cumulative_length + input_length - 1], dtype=torch.int32, ) ) @@ -1038,8 +1030,8 @@ class FlashCausalLMBatch(Batch): self.prefill_cache_indices = ( prefill_cache_indices.to(device) if sliding_window is not None else None ) - self.postfix_lengths_tensor = torch.tensor( - self.postfix_lengths, dtype=torch.int32, device=device + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device ) if all_prefill_logprobs: @@ -1059,8 +1051,8 @@ class FlashCausalLMBatch(Batch): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices self.slots = torch.tensor(slots, dtype=torch.int64, device=device) - self.prefix_lengths_tensor = torch.tensor( - self.prefix_lengths, dtype=torch.int32, device=device + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device ) adapter_indices = torch.cat(adapter_indices_list).to( dtype=torch.int64, device=device @@ -1276,12 +1268,12 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - postfix_lengths = [max_s] * bs - prefix_lengths = [0] * bs - postfix_lengths_tensor = ( + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + input_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) - prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) block_tables = 
torch.arange( max_bt, dtype=torch.int32, device=self.device ).repeat(bs) @@ -1290,8 +1282,8 @@ class FlashCausalLM(Model): if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths, + input_lengths=input_lengths, + cache_lengths=cache_lengths, ) from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, @@ -1319,8 +1311,8 @@ class FlashCausalLM(Model): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "postfix_lengths": postfix_lengths_tensor, - "prefix_lengths": prefix_lengths_tensor, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, "state": state, "graph": graph, } @@ -1330,13 +1322,13 @@ class FlashCausalLM(Model): with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=None, - postfix_lengths_tensor=postfix_lengths_tensor, + input_lengths_tensor=input_lengths_tensor, state=state, - prefix_lengths_tensor=prefix_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): seqlen = Seqlen( - postfix_lengths=postfix_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_s, @@ -1359,8 +1351,8 @@ class FlashCausalLM(Model): with torch.cuda.graph(graph, pool=MEM_POOL): seqlen = Seqlen( - postfix_lengths=postfix_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_s, @@ -1517,8 +1509,8 @@ class FlashCausalLM(Model): slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) # Dummy value, some models (starcoder2) don't accept `None`. 
- postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lengths_tensor = torch.zeros( + input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros( seqlen, dtype=torch.int32, device=self.device ) cu_seqlen_prefill = torch.tensor( @@ -1526,8 +1518,8 @@ class FlashCausalLM(Model): ) max_s = seqlen seqlen = Seqlen( - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=1, max_k=seqlen, @@ -1558,7 +1550,7 @@ class FlashCausalLM(Model): kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor + input_lengths = batch.input_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -1575,11 +1567,11 @@ class FlashCausalLM(Model): position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - postfix_lengths = ( - postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lengths_tensor = ( - batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -1600,8 +1592,8 @@ class FlashCausalLM(Model): kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor - prefix_lengths_tensor = batch.prefix_lengths_tensor + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -1623,19 +1615,19 @@ class FlashCausalLM(Model): if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - postfix_lengths_tensor=postfix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (postfix_lengths + prefix_lengths_tensor).max().item() + max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, @@ -1664,8 +1656,8 @@ class FlashCausalLM(Model): if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) # assert block_tables.shape[0] >= slots.shape[0] cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables @@ -1678,18 +1670,18 @@ class FlashCausalLM(Model): # so it doesn't matter if we override it with bogus values. 
cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["postfix_lengths"].zero_() - cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][ - : prefix_lengths_tensor.shape[0] - ] = prefix_lengths_tensor + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor with self._forward_context( block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - postfix_lengths_tensor=cuda_graph["postfix_lengths"], - prefix_lengths_tensor=cuda_graph["prefix_lengths"], + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], state=cuda_graph["state"], ): # Replay the graph @@ -1775,13 +1767,13 @@ class FlashCausalLM(Model): batch_budget = get_max_prefill_tokens() - (len(batch) - 1) # We reverse to prioritize older requests # zip() is not reversible so reverse the underlying lists instead - for prefix_length, postfix_length, prompt_length in zip( - reversed(batch.prefix_lengths), - reversed(batch.postfix_lengths), + for cache_length, input_length, prompt_length in zip( + reversed(batch.cache_lengths), + reversed(batch.input_lengths), reversed(batch.prompt_lengths), ): remaining_prefill_tokens = max( - prompt_length - prefix_length - postfix_length, 0 + prompt_length - cache_length - input_length, 0 ) if remaining_prefill_tokens > 0: next_chunk_length = max( @@ -1842,8 +1834,8 @@ class FlashCausalLM(Model): # Zipped iterator iterator = zip( batch.prompt_lengths, - batch.prefix_lengths, - batch.postfix_lengths, + batch.cache_lengths, + batch.input_lengths, batch.all_input_ids, accepted_ids, ) @@ -1858,14 +1850,14 @@ class FlashCausalLM(Model): cumulative_length = 0 for i, ( prompt_length, - prefix_length, - postfix_length, + cache_length, + input_length, all_input_ids, n_accepted_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length - end_index = cumulative_length + postfix_length + end_index = cumulative_length + input_length if prefill: # Indexing metadata @@ -1899,17 +1891,17 @@ class FlashCausalLM(Model): # Represent whether this request is still prefilling # If it is, the tokens we decoded should be ignored - accept_tokens = prefix_length + postfix_length >= prompt_length + accept_tokens = cache_length + input_length >= prompt_length if accept_tokens: # Only save tokens if we are done prefilling for this request for j in range(n_accepted_ids): - batch.all_input_ids_tensor[ - i, prefix_length + postfix_length + j - ] = next_input_ids[index] + batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( + next_input_ids[index] + ) index += 1 - cumulative_length += postfix_length + cumulative_length += input_length # Update values # These values can be updated without a GPU -> CPU sync @@ -1917,8 +1909,8 @@ class FlashCausalLM(Model): batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids - batch.prefix_lengths_tensor += batch.postfix_lengths_tensor - batch.postfix_lengths_tensor = accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids batch.slot_indices += accepted_ids batch.adapter_meta.adapter_indices = next_adapter_indices @@ -1959,24 +1951,24 @@ class 
FlashCausalLM(Model): request_prefilling, next_token_id, all_input_ids, - prefix_length, - postfix_length, + cache_length, + input_length, next_chunk_length, ) in enumerate( zip( batch.prefilling_mask, next_token_ids, batch.all_input_ids, - batch.prefix_lengths, - batch.postfix_lengths, + batch.cache_lengths, + batch.input_lengths, next_chunk_lengths, ) ): if request_prefilling: - next_prefix_length = prefix_length + postfix_length + next_cache_length = cache_length + input_length # Get new prompt IDs to prefill postfix_ids = all_input_ids[ - next_prefix_length : next_prefix_length + next_chunk_length + next_cache_length : next_cache_length + next_chunk_length ] else: # This request is done prefilling, the new id is the one selected the sampling method @@ -1996,8 +1988,8 @@ class FlashCausalLM(Model): iterator = zip( batch.requests, batch.prompt_lengths, - batch.prefix_lengths, - batch.postfix_lengths, + batch.cache_lengths, + batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, @@ -2012,15 +2004,15 @@ class FlashCausalLM(Model): batch_top_token_logprobs, ) - # Reset max_postfix_length - batch.max_postfix_length = 0 + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch index = 0 for i, ( request, prompt_length, - prefix_length, - postfix_length, + cache_length, + input_length, prefix_offset, read_offset, stopping_criteria, @@ -2084,9 +2076,9 @@ class FlashCausalLM(Model): # Make sure that we do not stop as even though this request did not create a token, it is still # processing stopped = False - new_postfix_length = next_chunk_lengths[i] + new_input_length = next_chunk_lengths[i] else: - new_postfix_length = n_accepted_ids + new_input_length = n_accepted_ids # Append next token to all tokens next_token_texts = [] left = 0 @@ -2198,14 +2190,12 @@ class FlashCausalLM(Model): ) # Update values - current_prefix_length = prefix_length + postfix_length - batch.prefix_lengths[i] = current_prefix_length - current_postfix_length = new_postfix_length - batch.max_postfix_length = max( - batch.max_postfix_length, current_postfix_length - ) - batch.postfix_lengths[i] = current_postfix_length - current_length = current_prefix_length + current_postfix_length + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset @@ -2235,8 +2225,8 @@ class FlashCausalLM(Model): *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - postfix_lengths_tensor: torch.Tensor, - prefix_lengths_tensor: torch.Tensor, + input_lengths_tensor: torch.Tensor, + cache_lengths_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": @@ -2247,7 +2237,7 @@ class FlashCausalLM(Model): use_prefill_with_paged_kv_state, ) - # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths) + # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths) if cu_seqlen_prefill is not None: return use_prefill_with_paged_kv_state( @@ -2256,12 +2246,12 @@ class FlashCausalLM(Model): ), # block_tables=block_tables_to_ragged( # block_tables=block_tables, - # postfix_lengths=postfix_lengths, - # 
prefix_lengths=prefix_lengths, + # input_lengths=input_lengths, + # cache_lengths=cache_lengths, # ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - input_lengths=postfix_lengths_tensor + prefix_lengths_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -2270,10 +2260,10 @@ class FlashCausalLM(Model): window_left=self.sliding_window, ) else: - assert postfix_lengths_tensor is not None + assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=postfix_lengths_tensor + prefix_lengths_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, @@ -2285,21 +2275,19 @@ class FlashCausalLM(Model): def block_tables_to_ragged( - *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int] + *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] ) -> torch.Tensor: """Convert block table to ragged format compatible with FlashInfer.""" - assert len(postfix_lengths) == len(prefix_lengths) + assert len(input_lengths) == len(cache_lengths) - total_len = sum(postfix_lengths) + sum(prefix_lengths) + total_len = sum(input_lengths) + sum(cache_lengths) block_tables_ragged = torch.empty( total_len, dtype=torch.int32, device=block_tables.device ) offset = 0 - for i, (input_length, prefix_length) in enumerate( - zip(postfix_lengths, prefix_lengths) - ): - seq_len = prefix_length + input_length + for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): + seq_len = cache_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index 9e19e171..3aa475c3 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -285,7 +285,7 @@ class MllamaCausalLM(VlmCausalLM): max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=prefix_lens_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7484e448..a06add13 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM): kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor + input_lengths = batch.input_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM): position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - postfix_lengths = ( - postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lengths_tensor = ( - batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + 
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM): kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] - postfix_lengths = batch.postfix_lengths_tensor - prefix_lengths_tensor = batch.prefix_lengths_tensor + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices @@ -359,19 +359,19 @@ class VlmCausalLM(FlashCausalLM): if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - postfix_lengths_tensor=postfix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (postfix_lengths + prefix_lengths_tensor).max().item() + max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( - postfix_lengths=postfix_lengths, - prefix_lengths=prefix_lengths_tensor, + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k, @@ -408,8 +408,8 @@ class VlmCausalLM(FlashCausalLM): if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, - postfix_lengths=batch.postfix_lengths, - prefix_lengths=batch.prefix_lengths, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: @@ -418,18 +418,18 @@ class VlmCausalLM(FlashCausalLM): ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["postfix_lengths"].zero_() - cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][ - : prefix_lengths_tensor.shape[0] - ] = prefix_lengths_tensor + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor with self._forward_context( block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - postfix_lengths_tensor=cuda_graph["postfix_lengths"], - prefix_lengths_tensor=cuda_graph["prefix_lengths"], + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], state=cuda_graph["state"], ): # Replay the graph
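
Reviewer sketch (not part of the patch): after the rename, the Seqlen bookkeeping in layers/attention/common.py reads as "tokens fed this step" (input_lengths) plus "tokens already in the KV cache" (cache_lengths). The standalone snippet below mirrors the flashinfer/flashdecoding branch of the renamed dataclass under simplifying assumptions (CPU tensors, no ATTENTION switch, class renamed to SeqlenSketch so it does not clash with the real one); it only illustrates how cu_seqlen_k is derived from the two length tensors.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SeqlenSketch:
    # input_lengths: tokens fed to the model in this forward pass (was postfix_lengths)
    # cache_lengths: tokens already present in the KV cache        (was prefix_lengths)
    input_lengths: torch.Tensor
    cache_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor] = None
    cu_seqlen_k: Optional[torch.Tensor] = None

    def __post_init__(self):
        bs = self.input_lengths.shape[0]
        if self.cu_seqlen_q is None:
            # decode case: one query token per sequence -> [0, 1, ..., bs]
            self.cu_seqlen_q = torch.arange(bs + 1, dtype=torch.int32)
        # keys span the cached tokens plus the new input tokens
        cu_seqlen_k = torch.zeros(bs + 1, dtype=torch.int32)
        total = self.input_lengths + self.cache_lengths
        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
        self.cu_seqlen_k = cu_seqlen_k


# Two sequences with 3 and 5 cached tokens, each decoding 1 new token.
seqlen = SeqlenSketch(
    input_lengths=torch.tensor([1, 1], dtype=torch.int32),
    cache_lengths=torch.tensor([3, 5], dtype=torch.int32),
)
print(seqlen.cu_seqlen_q.tolist())  # [0, 1, 2]
print(seqlen.cu_seqlen_k.tolist())  # [0, 4, 10]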
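
The chunked-prefill budget check in FlashCausalLM also becomes easier to read with the new names: what is still left to prefill for a request is its prompt minus what is already cached minus what is being fed in the current pass. A toy arithmetic sketch follows; the numbers are made up, and the real code additionally clamps the next chunk to the batch-level prefill budget.

# prompt_length: full tokenized prompt
# cache_length:  tokens already in the KV cache (prefix-cache hit + earlier chunks)
# input_length:  tokens scheduled for this forward pass
prompt_length = 1000
cache_length = 512
input_length = 256

remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0)
print(remaining_prefill_tokens)  # 232 -> this request needs at least one more prefill chunk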
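
Likewise for the per-step length update in generate_token: the tokens just processed move from input_lengths into cache_lengths, and input_lengths becomes the number of accepted ids for the next step. A minimal sketch with toy tensors standing in for the batch fields (no speculation, one accepted token per request):

import torch

cache_lengths_tensor = torch.tensor([3, 5], dtype=torch.int32)  # already in the KV cache
input_lengths_tensor = torch.tensor([4, 2], dtype=torch.int32)  # chunk processed this step
accepted_ids = torch.tensor([1, 1], dtype=torch.int32)          # tokens accepted per request

# mirrors: batch.cache_lengths_tensor += batch.input_lengths_tensor
#          batch.input_lengths_tensor = accepted_ids
cache_lengths_tensor += input_lengths_tensor
input_lengths_tensor = accepted_ids

print(cache_lengths_tensor.tolist())  # [7, 7]
print(input_lengths_tensor.tolist())  # [1, 1]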
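
Finally, a usage sketch for the renamed block_tables_to_ragged helper. The function body below is copied from the diff; the call demonstrates the invariant that each request contributes exactly cache_length + input_length entries to the ragged table, which is the layout flashinfer consumes. The block-table values themselves are made up.

from typing import List

import torch


def block_tables_to_ragged(
    *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
) -> torch.Tensor:
    """Convert block table to ragged format compatible with FlashInfer."""
    assert len(input_lengths) == len(cache_lengths)

    total_len = sum(input_lengths) + sum(cache_lengths)
    block_tables_ragged = torch.empty(
        total_len, dtype=torch.int32, device=block_tables.device
    )

    offset = 0
    for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
        seq_len = cache_length + input_length
        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
        offset += seq_len

    return block_tables_ragged


# Two requests with padded block tables of width 4:
# request 0 spans 3 positions (2 cached + 1 new), request 1 spans 2 (1 cached + 1 new).
block_tables = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
ragged = block_tables_to_ragged(
    block_tables=block_tables, input_lengths=[1, 1], cache_lengths=[2, 1]
)
print(ragged.tolist())  # [10, 11, 12, 20, 21]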