diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index 479d31bf..58703c07 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -159,6 +159,7 @@ impl Client {
                 blocks: vec![],
                 slots: vec![],
                 prefix_len: 0,
+                suffix_len: 0,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index 645c076a..476e7a39 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -246,6 +246,7 @@ impl Health for ShardedClient {
             blocks: vec![0],
             slots: (0..16).collect(),
             prefix_len: 0,
+            suffix_len: 0,
             adapter_id: None,
         };
         let batch = Batch {
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index 648662db..f0f21854 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -113,16 +113,23 @@ impl Client {
         max_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
-        let mut n_tokens = 0;
+        let mut rest_tokens = max_prefill_tokens;
         let mut requests = Vec::new();
+
+        let max_tokens_per_request = core::cmp::min(max_input_length, max_prefill_tokens);
         // Create requests
-        while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+        while rest_tokens > 0 {
+            let curr_tokens = min(max_tokens_per_request, rest_tokens);
+            let truncate = min(max_input_length, rest_tokens);
+            let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
+            let suffix_len = 0;

             let mut input_chunks = Vec::new();
             input_chunks
                 .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
-            if n_tokens == 0 {
+            let mut inputs = String::new();
+            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            if rest_tokens == max_prefill_tokens {
                 input_chunks.push(
                     Chunk::Image(Image {
                         // Safe unwrap, because we control the data.
@@ -131,14 +138,6 @@ impl Client {
                     })
                     .into(),
                 );
-            }
-
-            // Send stringly-typed inputs for compatibility for backends that haven't
-            // been updated to support chunks.
-
-            let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
-            if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
                 inputs.push_str(&format!(
@@ -158,7 +157,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
+                prefix_len,
+                suffix_len,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -182,7 +182,7 @@ impl Client {
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            rest_tokens -= curr_tokens;

             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index ea77a696..5fe7d37d 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -247,6 +247,7 @@ impl Health for ShardedClient {
             blocks: vec![0],
             slots: (0..16).collect(),
             prefix_len: 0,
+            suffix_len: 0,
             adapter_id: None,
         };
         let batch = Batch {
diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs
index 471ddb5a..56687075 100644
--- a/backends/v3/src/main.rs
+++ b/backends/v3/src/main.rs
@@ -131,9 +131,9 @@ async fn main() -> Result<(), RouterError> {
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
-    }
+    // if max_input_tokens as u32 > max_batch_prefill_tokens {
+    //     return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+    // }

     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index f8123b57..ff9879f3 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -278,7 +278,7 @@ impl State {
                     decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                     let total_tokens = prefill_tokens + decode_tokens + self.speculate;

-                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+                    if total_tokens > token_budget {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -298,9 +298,7 @@ impl State {
                     };

                     decode_tokens += max_new_tokens;
-                    if prefill_tokens > prefill_token_budget
-                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                    {
+                    if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -392,6 +390,7 @@ impl State {
                     block_allocation.prefix_len,
                 ),
             };
+            let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);

             entry.block_allocation = block_allocation;

@@ -428,6 +427,7 @@ impl State {
                 blocks,
                 slots,
                 prefix_len,
+                suffix_len,
                 adapter_id: entry.request.adapter_id.clone(),
             });
             // Set batch_time
diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 789c7b51..0080bb57 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -159,6 +159,7 @@ async fn prefill(
             blocks: vec![],
             slots: vec![],
             prefix_len: 0,
+            suffix_len: 0,
             adapter_id: None,
         })
         .collect();
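
For readers of this patch, the following standalone sketch mirrors the warmup budgeting that the grpc_client.rs hunk above switches to. It is an illustration only, not part of the diff: `WarmupChunk` and `plan_warmup_chunks` are hypothetical names. The idea is that the prefill budget is split into per-request chunks of at most `max_input_length` tokens, and an input longer than the budget is expressed through `prefix_len` instead of being rejected up front.

// Illustration only: mirrors the loop in Client::warmup after this change.
#[derive(Debug, PartialEq)]
struct WarmupChunk {
    curr_tokens: u32,
    truncate: u32,
    prefix_len: u32,
    suffix_len: u32,
}

fn plan_warmup_chunks(max_input_length: u32, max_prefill_tokens: u32) -> Vec<WarmupChunk> {
    let mut rest_tokens = max_prefill_tokens;
    let max_tokens_per_request = max_input_length.min(max_prefill_tokens);
    let mut chunks = Vec::new();
    while rest_tokens > 0 {
        let curr_tokens = max_tokens_per_request.min(rest_tokens);
        chunks.push(WarmupChunk {
            curr_tokens,
            truncate: max_input_length.min(rest_tokens),
            // Any part of the input that exceeds the prefill budget is treated as prefix.
            prefix_len: max_input_length.saturating_sub(max_prefill_tokens),
            suffix_len: 0,
        });
        rest_tokens -= curr_tokens;
    }
    chunks
}

fn main() {
    // Budget larger than the input: the budget is spread over several requests.
    let spread = plan_warmup_chunks(1024, 4096);
    assert_eq!(spread.len(), 4);
    assert!(spread.iter().all(|c| c.curr_tokens == 1024 && c.prefix_len == 0));

    // Input larger than the budget (previously rejected by the launcher/router
    // validation this patch comments out): one request, with the overflow as prefix.
    let chunked = plan_warmup_chunks(4096, 1024);
    assert_eq!(chunked.len(), 1);
    assert_eq!(chunked[0].prefix_len, 3072);
}
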
diff --git a/integration-tests/models/test_flash_llama_chunking.py b/integration-tests/models/test_flash_llama_chunking.py
new file mode 100644
index 00000000..5ae7cc69
--- /dev/null
+++ b/integration-tests/models/test_flash_llama_chunking.py
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_handle(launcher):
+    with launcher(
+        "huggingface/llama-7b", num_shard=2, max_batch_prefill_tokens=2
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama(flash_llama_handle):
+    await flash_llama_handle.health(300)
+    return flash_llama_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama(flash_llama, response_snapshot):
+    response = await flash_llama.generate("What is Deep Learning ?", max_new_tokens=10)
+
+    assert response.details.generated_text == "xx"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_load(flash_llama, generate_load, response_snapshot):
+    responses = await generate_load(
+        flash_llama, "What is Deep Learning ?", max_new_tokens=10, n=4
+    )
+    assert responses[0].details.generated_text == "xx"
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2cdccfe0..d93898ad 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1675,12 +1675,12 @@ fn main() -> Result<(), LauncherError> {
             "`max_input_tokens must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
+    // if max_input_tokens as u32 > max_batch_prefill_tokens {
+    //     return Err(LauncherError::ArgumentValidation(format!(
+    //         "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
+    //         max_batch_prefill_tokens, max_input_tokens
+    //     )));
+    // }

     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
         tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 34894bda..9bba230a 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -137,8 +137,10 @@ message Request {
   optional string adapter_id = 11;
   /// Prefix length that can be retrieved from the KV cache.
   uint32 prefix_len = 12;
+  /// Part of the query that does not need to be computed right away.
+  uint32 suffix_len = 13;
   /// Context truncation
-  bool add_special_tokens = 13;
+  bool add_special_tokens = 14;
 }

 message Batch {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index a2834962..3844a990 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -269,6 +269,9 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)

             prefix_len = r.prefix_len
+            import ipdb
+
+            ipdb.set_trace()
             assert (
                 prefix_len <= orig_input_length
             ), f"Prefix {prefix_len} vs input {orig_input_length}"
@@ -318,11 +321,11 @@ class FlashCausalLMBatch(Batch):
             speculative_length = 0 if speculative_length is None else speculative_length

             # Tokens that need to be mapped to blocks.
-            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+            block_tokens = orig_input_length + max_new_tokens + speculative_length

             # Tokens that need to be mapped to slots. We don't need slots for the
             # cached prefix (if present).
-            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+            slot_tokens = input_length + max_new_tokens + speculative_length

             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
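
To make the scheduling change in queue.rs easier to follow, here is a minimal sketch of the accounting after this patch. It is an illustration only: `fits_in_budget` and `split_prompt` are hypothetical names, not functions in the codebase. Only the combined prefill + decode + speculation total is checked against the token budget, and the slots allocated for a prompt are split into the KV-cached prefix and the `suffix_len` that still has to be computed.

// Illustration only (hypothetical helper names), assuming the queue semantics above.
fn fits_in_budget(prefill_tokens: u32, decode_tokens: u32, speculate: u32, token_budget: u32) -> bool {
    // The separate prefill_token_budget check is gone; only the total is enforced.
    prefill_tokens + decode_tokens + speculate <= token_budget
}

fn split_prompt(slots: &[u32], prefix_len: u32) -> (u32, u32) {
    // `prefix_len` slots are already covered by the KV cache; the remainder is the
    // suffix the shard still has to run a forward pass over.
    let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
    (prefix_len, suffix_len)
}

fn main() {
    assert!(fits_in_budget(512, 128, 0, 1024));
    assert!(!fits_in_budget(900, 200, 0, 1024));
    // 160 allocated slots with 96 cached as prefix leaves a 64-token suffix.
    assert_eq!(split_prompt(&[0u32; 160], 96), (96, 64));
}
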