diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index f8a10ca2..77fdb041 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -36,9 +36,9 @@ impl BackendV3 {
         speculate: u32,
     ) -> Self {
         let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
+            std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
         let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
+        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
         let attention: Attention = attention
             .parse()
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index d99771f8..1efeba58 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -2,10 +2,6 @@ import pytest
 import os
 from text_generation_server.pb import generate_pb2

-os.environ["USE_PREFIX_CACHING"] = "1"
-os.environ["ATTENTION"] = "flashinfer"
-
-
 @pytest.fixture
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8933f13e..33fe30a8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -149,26 +149,11 @@ class FlashCausalLMBatch(Batch):
     max_seqlen: int

-    # Prefill metadata tensors
+    # Prefill metadata tensors to efficiently compute logprobs
     prefill_head_indices: Optional[torch.Tensor]
-    prefill_next_token_indices: Optional[torch.Tensor]
+    prefill_next_token_indices: Optional[torch.tensor]
     prefill_cu_outlens: Optional[List[int]]

-    # Whether at least one request is prefilling/chunking
-    # == any(prefilling_mask)
-    prefilling: bool
-    # For each request, whether they are still prefilling/chunking
-    prefilling_mask: List[bool]
-    # For each request, whether the model output should be used or discarded
-    # If we are chunking, we don't care about the output as it might be different
-    # from the token in the prompt
-    use_output_token: List[bool]
-
-    # If the request is decoding, `next_chunk_length = 1`
-    # `None if not batch.prefilling`
-    next_chunk_lengths: Optional[List[int]]
-    next_chunk_lengths_tensor: Optional[torch.Tensor]
-
     # Prefixes
     prefix_ids: List[List[int]]
@@ -247,14 +232,11 @@ class FlashCausalLMBatch(Batch):
         prefix_ids = []
         requests_idx_mapping = {}

-        chunking = False
         all_prefill_logprobs = True
         no_prefill_logprobs = True
         prefill_head_indices = []
         prefill_next_token_indices = []
         prefill_cu_outlens = [0]
-        next_chunk_lengths = []
-        use_output_token = []

         next_token_chooser_parameters = []
         stopping_criterias = []
@@ -294,7 +276,6 @@ class FlashCausalLMBatch(Batch):
                 assert prefix_len > 0
                 prefix_len -= 1

-            # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")

             prefix_ids.append(tokenized_input[:prefix_len])
@@ -303,18 +284,9 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

-            if True:
-                # This request only requires one prefill and no chunking
-                use_output_token.append(True)
-                next_chunk_lengths.append(1)
-            else:
-                chunking = True
-                raise NotImplementedError
-
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)

-            # FIXME: use all input tokens not just postfix ones
             all_input_ids.append(tokenized_input)

             # Position ids
@@ -385,7 +357,6 @@ class FlashCausalLMBatch(Batch):
             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
-                raise NotImplementedError
                 request_prefill_cache_indices = torch.arange(
                     cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
@@ -397,7 +368,6 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

             if r.prefill_logprobs:
-                raise NotImplementedError
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
@@ -475,12 +445,6 @@ class FlashCausalLMBatch(Batch):
             adapter_segments, dtype=torch.int32, device=device
         )

-        if chunking:
-            next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
-        else:
-            next_chunk_lengths = None
-            next_chunk_lengths_tensor = None
-
         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
@@ -527,11 +491,6 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            prefilling=True,
-            prefilling_mask=[True] * pb.requests.len(),
-            use_output_token=use_output_token,
-            next_chunk_lengths=next_chunk_lengths,
-            next_chunk_lengths_tensor=next_chunk_lengths_tensor,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
@@ -1467,7 +1426,7 @@ class FlashCausalLM(Model):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

-        if not batch.prefilling and self.max_past() is not None:
+        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
@@ -1481,7 +1440,7 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None

-        if batch.prefilling or cuda_graph is None:
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1516,7 +1475,6 @@ class FlashCausalLM(Model):
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:
-                raise NotImplementedError
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
@@ -1570,6 +1528,7 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
+        prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)
@@ -1595,13 +1554,13 @@ class FlashCausalLM(Model):
         adapter_data = AdapterBatchData.from_meta(
             adapter_meta,
             self.layer_to_adapter_weights,
-            batch.prefilling,
+            prefill,
             batch.prefill_head_indices,
         )

         out, speculative_logits = self.forward(batch, adapter_data)

-        if batch.prefilling:
+        if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
@@ -1638,31 +1597,22 @@ class FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

-        if batch.prefilling:
+        if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            if batch.next_chunk_lengths is None:
-                # We are done prefilling after this forward
-                next_position_ids = batch.position_ids.new_empty(len(batch))
-                # [BATCH_SIZE]
-                # Last slot for each request, will be incremented later
-                batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            else:
-                # We still have prefill chunks to go through
-                next_forward_size = sum(batch.next_chunk_lengths)
-                next_position_ids = batch.position_ids.new_empty(next_forward_size)
-                batch.slot_indices = batch.slot_indices.new_empty(next_forward_size)
-                batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
+            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
+            # We do not need cu_seqlen_prefill anymore
+            batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids

         # Cumulative length
         cumulative_length = 0
-        cumulative_chunk_lengths = 0

         # Results
         generations: List[Generation] = []
@@ -1675,32 +1625,21 @@ class FlashCausalLM(Model):
         # one, we need to first do a GPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time
-        index = 0
         # For each member of the batch
+        index = 0
         for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

-            if batch.prefilling:
-                if batch.next_chunk_lengths is not None:
-                    next_chunk_length = batch.next_chunk_lengths[i]
-                else:
-                    next_chunk_length = 1
-
+            if prefill:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
                 out_length = out_end_index - out_start_index

-                position_start_index =
-
                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
-
-
-                next_position_ids
-
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Initialize adapter indices
@@ -1712,7 +1651,6 @@ class FlashCausalLM(Model):
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
-                    raise NotImplementedError
                     if len(batch) > 1:
                         prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                             batch.input_ids[start_index + 1 : start_index + out_length]
@@ -1730,23 +1668,14 @@ class FlashCausalLM(Model):
             cumulative_length += input_length

         # Update values
-        if batch.next_prefilling_chunk_lengths is None:
-            # We are done prefilling
-            batch.prefilling = False
-            batch.next_prefilling_chunk_lengths = None
-            batch.next_prefilling_chunk_lengths_tensor = None
-            # We do not need cu_seqlen_prefill anymore
-            batch.cu_seqlen_prefill = None
+        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+        batch.speculative_ids = speculative_ids
+        batch.position_ids = next_position_ids + accepted_ids
+        batch.input_lengths_tensor += accepted_ids
+        batch.slot_indices += accepted_ids
+        batch.adapter_meta.adapter_indices = next_adapter_indices

-        if not batch.prefilling:
-            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
-            batch.speculative_ids = speculative_ids
-            batch.position_ids = next_position_ids + accepted_ids
-            batch.input_lengths_tensor += accepted_ids
-            batch.slot_indices += accepted_ids
-            batch.adapter_meta.adapter_indices = next_adapter_indices
-
-        if batch.prefilling:
+        if prefill:
             # adjust segment lengths to account for all request lengths being 1 during decoding
             adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
             batch.adapter_meta.adapter_segments = torch.tensor(
@@ -1755,8 +1684,7 @@ class FlashCausalLM(Model):
                 device=batch.adapter_meta.adapter_segments.device,
             )

-        if batch.prefilling and prefill_logprobs:
-            raise NotImplementedError
+        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
@@ -1867,8 +1795,7 @@ class FlashCausalLM(Model):
             generated_text = None

             # Prefill
-            if batch.prefilling and request.prefill_logprobs:
-                raise NotImplementedError
+            if prefill and request.prefill_logprobs:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 6c518c2c..1830dc42 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,9 +5,9 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-ATTENTION = os.getenv("ATTENTION")
+ATTENTION = os.getenv("ATTENTION", "flashinfer")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
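
Note (illustrative, not part of the patch): the flash_causal_lm.py hunks above drop the explicit batch.prefilling flag and instead derive prefill-ness from whether cu_seqlen_prefill is still set, clearing it once the prefill forward pass has produced its tokens. A minimal sketch of that invariant, using a hypothetical ToyBatch stand-in rather than the real FlashCausalLMBatch:

# Sketch only: ToyBatch is a hypothetical stand-in for FlashCausalLMBatch.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyBatch:
    # Cumulative prefill sequence lengths; populated for a prefill batch,
    # set back to None once prefill has completed.
    cu_seqlen_prefill: Optional[torch.Tensor]


prefill_batch = ToyBatch(cu_seqlen_prefill=torch.tensor([0, 5, 9], dtype=torch.int32))
decode_batch = ToyBatch(cu_seqlen_prefill=None)

# Mirrors `prefill = batch.cu_seqlen_prefill is not None` in generate_token().
assert prefill_batch.cu_seqlen_prefill is not None
assert decode_batch.cu_seqlen_prefill is None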
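
Also illustrative: with the defaults introduced in backend.rs and globals.py, USE_PREFIX_CACHING now falls back to "1" and ATTENTION to "flashinfer" when the variables are unset, which is why the test conftest no longer needs to export them. A minimal sketch of the resolution logic (the helper name resolve_runtime_flags is made up for this example; the real values live in server/text_generation_server/models/globals.py and backends/v3/src/backend.rs):

# Sketch only: mirrors the new environment-variable defaults from the patch.
import os
from typing import Tuple


def resolve_runtime_flags() -> Tuple[bool, str]:
    # USE_PREFIX_CACHING now defaults to "1" (enabled) when unset.
    prefix_caching = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
    # ATTENTION now defaults to "flashinfer" when unset.
    attention = os.getenv("ATTENTION", "flashinfer")
    assert attention in {"paged", "flashdecoding", "flashinfer"}
    return prefix_caching, attention


if __name__ == "__main__":
    # With neither variable set, this prints: (True, 'flashinfer')
    print(resolve_runtime_flags())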