diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs
index 0bb5dc0c..2e2e9a11 100644
--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs
@@ -179,18 +179,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .await
         .expect("Unable to clear cache");
 
-    // Warmup shard
-    let max_batch_size = batch_size.iter().max().unwrap();
-    sharded_client
-        .warmup(
-            sequence_length,
-            sequence_length * max_batch_size,
-            (sequence_length + decode_length) * max_batch_size,
-            Some(*max_batch_size as usize),
-        )
-        .await
-        .expect("Unable to warmup");
-
     tracing::info!("Connected");
 
     // Run app
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 214adcdc..2389ac63 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1727,12 +1727,6 @@ fn main() -> Result<(), LauncherError> {
             "`max_input_tokens must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
 
     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
         tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@@ -1786,12 +1780,6 @@ fn main() -> Result<(), LauncherError> {
     }
 
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_batch_prefill_tokens > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_batch_prefill_tokens, max_batch_total_tokens
-            )));
-        }
         if max_total_tokens as u32 > *max_batch_total_tokens {
             return Err(LauncherError::ArgumentValidation(format!(
                 "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
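The two launcher checks removed above enforced `max_input_tokens <= max_batch_prefill_tokens` and `max_batch_prefill_tokens <= max_batch_total_tokens` at startup. With chunked prefill, a prompt longer than the per-forward prefill budget is simply processed over several prefill passes, so neither ordering has to hold any more. A minimal sketch of the idea, illustrative only (`chunk_prefill` is not a TGI function):

```python
from typing import List


def chunk_prefill(prompt_ids: List[int], prefill_budget: int) -> List[List[int]]:
    """Split a prompt into prefill chunks of at most `prefill_budget` tokens."""
    return [
        prompt_ids[i : i + prefill_budget]
        for i in range(0, len(prompt_ids), prefill_budget)
    ]


# A 10k-token prompt still fits a ~4k prefill budget: it just takes three
# prefill passes, each resuming from the tokens already in the KV cache.
chunks = chunk_prefill(list(range(10_000)), 4_096)
assert [len(c) for c in chunks] == [4_096, 4_096, 1_808]
```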
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4e9f9c66..8d33a2b3 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -173,9 +173,6 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward
     prefill_logprob_tokens: List[Optional[Tokens]]
 
-    # Prefixes
-    prefix_ids: List[List[int]]
-
     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor
@@ -259,7 +256,6 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         all_postfix_ids = []
-        prefix_ids = []
         requests_idx_mapping = {}
 
         next_token_chooser_parameters = []
@@ -297,7 +293,6 @@ class FlashCausalLMBatch(Batch):
                 assert get_support_chunking()
                 assert input_length > 0
 
-            prefix_ids.append(tokenized_input[:cache_length])
             postfix_ids = tokenized_input[cache_length : cache_length + input_length]
 
             assert (
@@ -400,7 +395,6 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
-            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -464,7 +458,6 @@ class FlashCausalLMBatch(Batch):
         requests = []
         block_tables = []
         all_input_ids = []
-        prefix_ids = []
         input_ids = []
 
         prompt_lengths = []
@@ -505,7 +498,6 @@ class FlashCausalLMBatch(Batch):
             )
 
             all_input_ids.append(self.all_input_ids[idx])
-            prefix_ids.append(self.prefix_ids[idx])
 
             prompt_lengths.append(self.prompt_lengths[idx])
             input_lengths.append(request_input_length)
@@ -621,7 +613,6 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
-            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -718,7 +709,6 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         cache_lengths = []
         all_input_ids = []
-        prefix_ids = []
 
         prompt_lengths = []
         input_lengths = []
@@ -802,7 +792,6 @@ class FlashCausalLMBatch(Batch):
             block_tables.extend(batch.block_tables)
             cache_lengths.extend(batch.cache_lengths)
             all_input_ids.extend(batch.all_input_ids)
-            prefix_ids.extend(batch.prefix_ids)
 
             prompt_lengths.extend(batch.prompt_lengths)
             input_lengths.extend(batch.input_lengths)
@@ -873,7 +862,6 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
-            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -1839,6 +1827,8 @@ class FlashCausalLM(Model):
             batch.input_lengths,
             batch.all_input_ids,
             accepted_ids,
+            current_prefilling_mask,
+            batch.prefilling_mask,
         )
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -1855,6 +1845,8 @@ class FlashCausalLM(Model):
             input_length,
             all_input_ids,
             n_accepted_ids,
+            request_was_prefilling,
+            request_is_prefilling,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -1864,7 +1856,6 @@ class FlashCausalLM(Model):
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
-                out_length = out_end_index - out_start_index
 
                 if finished_prefilling:
                     # Initialize position_ids
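The next hunk introduces two index corrections, stated in its own comments: for a request whose prompt starts at `cache_length == 0`, the first token's logprob slot is a bogus token and is skipped by bumping `start_index`, and once a request finishes prefilling, the last output position holds the first generated token rather than a prompt token, so `out_end_index` is decremented. A toy walk-through with made-up lengths (the real `prefill_cu_outlens` bookkeeping is assumed):

```python
# Toy values; in the real code these come from the batch metadata.
cache_length = 0                # fresh request, nothing cached yet
request_was_prefilling = True
request_is_prefilling = False   # this forward pass finished the prefill

start_index, end_index = 0, 8          # this request's slice of batch.input_ids
out_start_index, out_end_index = 0, 8  # its slice of the prefill logprob output

# The first prompt token has no logprob to report, so skip it.
if request_was_prefilling and cache_length == 0:
    start_index += 1

# The last output position is the first *generated* token, not prefill.
if request_was_prefilling and not request_is_prefilling:
    out_end_index -= 1

# Both slices must line up for the scatter
# `prefill_tokens_indices[out_start_index:out_end_index] = batch.input_ids[start_index:end_index]`.
assert end_index - start_index == out_end_index - out_start_index == 7
```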
@@ -1880,21 +1871,25 @@ class FlashCausalLM(Model):
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
+                    # If the request was prefilling and cache_length == 0, the first token is a bogus token
+                    # and needs to be removed. We do so by incrementing the start_index
+                    if request_was_prefilling and cache_length == 0:
+                        start_index += 1
+
+                    # If the request was prefilling, and it is done prefilling, the last token was generated and is
+                    # therefore not part of the prefill. We remove it by decrementing out_end_index
+                    if request_was_prefilling and not request_is_prefilling:
+                        out_end_index -= 1
+
                     if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
-                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index:out_end_index] = (
+                            batch.input_ids[start_index:end_index]
                         )
                     else:
                         # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = batch.input_ids[
-                            start_index + 1 : start_index + out_length
-                        ]
+                        prefill_tokens_indices = batch.input_ids[start_index:end_index]
 
-            # Represent whether this request is still prefilling
-            # If it is, the tokens we decoded should be ignored
-            accept_tokens = cache_length + input_length >= prompt_length
-
-            if accept_tokens:
+            if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
                     batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
@@ -1995,7 +1990,6 @@ class FlashCausalLM(Model):
             batch.read_offsets,
             batch.stopping_criterias,
             batch.all_input_ids,
-            batch.prefix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -2019,7 +2013,6 @@ class FlashCausalLM(Model):
             read_offset,
             stopping_criteria,
             all_input_ids,
-            prefix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -2039,19 +2032,30 @@ class FlashCausalLM(Model):
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
 
+                # log_master(logger.info, f"{prefill_logprobs}")
+
+                if not request_is_prefilling:
+                    # If the request is done prefilling, then the last logprob is a generated token
+                    # We need to remove it
+                    out_end_index -= 1
+
                 request_prefill_logprobs = prefill_logprobs[
-                    out_start_index : out_end_index - 1
+                    out_start_index:out_end_index
+                ]
+                prefill_token_ids = all_input_ids[
+                    cache_length : cache_length + input_length
                 ]
-                prefill_token_ids = all_input_ids[:-1]
 
                 past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
 
                 if past_prefill_logprob_tokens is None:
                     # Remove generated token to only have prefill and add nan for first prompt token
                     request_prefill_logprobs = [float("nan")] * (
-                        len(prefix_ids) + 1
+                        cache_length + 1
                     ) + request_prefill_logprobs
-                    prefill_token_ids = prefix_ids + prefill_token_ids
+                    prefill_token_ids = (
+                        all_input_ids[:cache_length] + prefill_token_ids
+                    )
 
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
@@ -2059,6 +2063,10 @@ class FlashCausalLM(Model):
                     skip_special_tokens=False,
                 )
 
+                # log_master(logger.info, f"{prefill_token_ids}")
+                # log_master(logger.info, f"{request_prefill_logprobs}")
+                # log_master(logger.info, f"{prefill_texts}")
+
                 prefill_logprob_tokens = Tokens(
                     prefill_token_ids,
                     request_prefill_logprobs,
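The `prefix_ids` field removed throughout this diff was a parallel copy of the first `cache_length` tokens of each request (`tokenized_input[:cache_length]` in the old `from_tokenized` path). Since those tokens are already the head of `all_input_ids`, the new reporting code recovers both the cached prefix and the current chunk by slicing. A toy sketch of that bookkeeping, with made-up values (the real fields live on `FlashCausalLMBatch`):

```python
all_input_ids = [101, 7592, 2088, 2003, 3835, 102]  # full prompt so far
cache_length = 2   # tokens already in the KV cache before this chunk
input_length = 3   # tokens processed by this prefill chunk

# The chunk's tokens, previously tracked alongside a separate `prefix_ids` list:
postfix_ids = all_input_ids[cache_length : cache_length + input_length]
assert postfix_ids == [2088, 2003, 3835]

# First prefill report: prepend the cached prefix recovered by slicing, and
# pad its logprobs with one NaN per prefix token plus one for the first
# prompt token, mirroring the `[float("nan")] * (cache_length + 1)` line above.
chunk_logprobs = [-0.11, -2.3]  # made-up logprobs for this chunk (minus its first token)
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + chunk_logprobs
prefill_token_ids = all_input_ids[:cache_length] + postfix_ids
assert len(prefill_token_ids) == len(request_prefill_logprobs)
```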