TMP chunking.

Nicolas Patry 2024-09-02 11:46:36 +02:00
parent 38fcafcf96
commit 2f0fde1055
11 changed files with 77 additions and 30 deletions

View File

@@ -159,6 +159,7 @@ impl Client {
                 blocks: vec![],
                 slots: vec![],
                 prefix_len: 0,
+                suffix_len: 0,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,

View File

@@ -246,6 +246,7 @@ impl Health for ShardedClient {
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
+           suffix_len: 0,
            adapter_id: None,
        };
        let batch = Batch {

View File

@@ -113,16 +113,23 @@ impl Client {
         max_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
-        let mut n_tokens = 0;
+        let mut rest_tokens = max_prefill_tokens;
         let mut requests = Vec::new();
+        let max_tokens_per_request = core::cmp::min(max_input_length, max_prefill_tokens);
         // Create requests
-        while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+        while rest_tokens > 0 {
+            let curr_tokens = min(max_tokens_per_request, rest_tokens);
+            let truncate = min(max_input_length, rest_tokens);
+            let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
+            let suffix_len = 0;
             let mut input_chunks = Vec::new();
             input_chunks
                 .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
-            if n_tokens == 0 {
+            let mut inputs = String::new();
+            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            if rest_tokens == max_prefill_tokens {
                 input_chunks.push(
                     Chunk::Image(Image {
                         // Safe unwrap, because we control the data.
@@ -131,14 +138,6 @@ impl Client {
                     })
                     .into(),
                 );
-            }
-            // Send stringly-typed inputs for compatibility for backends that haven't
-            // been updated to support chunks.
-            let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
-            if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
                 inputs.push_str(&format!(
@@ -158,7 +157,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
+                prefix_len,
+                suffix_len,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -182,7 +182,7 @@ impl Client {
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            rest_tokens -= curr_tokens;

             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
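
Note on the warmup change above: instead of a single request of `max_prefill_tokens` tokens, the warmup now issues several requests whose sizes sum to the prefill budget, each capped at `max_input_length`. A minimal, self-contained sketch of that budgeting loop (the function name and signature are illustrative, not the crate's gRPC API):

use std::cmp::min;

/// Split a prefill token budget into per-request chunk sizes, each capped by
/// the maximum input length. Mirrors the loop in the diff above, minus the
/// request construction.
fn warmup_chunks(max_input_length: u32, max_prefill_tokens: u32) -> Vec<u32> {
    debug_assert!(max_input_length > 0, "input cap must be positive");
    let max_tokens_per_request = min(max_input_length, max_prefill_tokens);
    let mut rest_tokens = max_prefill_tokens;
    let mut chunks = Vec::new();
    while rest_tokens > 0 {
        let curr_tokens = min(max_tokens_per_request, rest_tokens);
        chunks.push(curr_tokens);
        rest_tokens -= curr_tokens;
    }
    chunks
}

fn main() {
    // A 4096-token prefill budget with a 1024-token input cap becomes four
    // 1024-token warmup requests instead of one oversized request.
    assert_eq!(warmup_chunks(1024, 4096), vec![1024, 1024, 1024, 1024]);
    // A budget smaller than the input cap still produces a single request.
    assert_eq!(warmup_chunks(1024, 100), vec![100]);
}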

View File

@@ -247,6 +247,7 @@ impl Health for ShardedClient {
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
+           suffix_len: 0,
            adapter_id: None,
        };
        let batch = Batch {

View File

@@ -131,9 +131,9 @@ async fn main() -> Result<(), RouterError> {
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
-   if max_input_tokens as u32 > max_batch_prefill_tokens {
-       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
-   }
+   // if max_input_tokens as u32 > max_batch_prefill_tokens {
+   //     return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+   // }
    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(

View File

@@ -278,7 +278,7 @@ impl State {
            decode_tokens += entry.request.stopping_parameters.max_new_tokens;
            let total_tokens = prefill_tokens + decode_tokens + self.speculate;

-           if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+           if total_tokens > token_budget {
                // Entry is over budget
                // Add it back to the front
                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -298,9 +298,7 @@ impl State {
                    };
                    decode_tokens += max_new_tokens;

-                   if prefill_tokens > prefill_token_budget
-                       || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                   {
+                   if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -392,6 +390,7 @@ impl State {
                        block_allocation.prefix_len,
                    ),
                };
+               let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);

                entry.block_allocation = block_allocation;
@@ -428,6 +427,7 @@ impl State {
                blocks,
                slots,
                prefix_len,
+               suffix_len,
                adapter_id: entry.request.adapter_id.clone(),
            });
            // Set batch_time
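
Note on the queue change above: the prefill-specific budget check is dropped, so an entry whose prompt alone exceeds `prefill_token_budget` can still be scheduled as long as the total (prefill + decode + speculation) fits, and the slots not covered by the cached prefix are reported as `suffix_len`. A rough sketch of that accounting, with simplified names (the real `State` in queue.rs tracks far more per-entry state):

/// Returns the suffix length to report for an entry if it fits the batch,
/// or None if it is over budget.
fn try_admit(
    prefill_tokens: u32,
    decode_tokens: u32,
    speculate: u32,
    token_budget: u32,
    slots: &[u64],
    prefix_len: u32,
) -> Option<u32> {
    // The old condition also rejected on `prefill_tokens > prefill_token_budget`;
    // with chunked prefill only the total budget matters.
    if prefill_tokens + decode_tokens + speculate > token_budget {
        return None; // caller pushes the entry back to the front of the queue
    }
    // Slots not covered by the reused prefix are reported as the suffix.
    Some((slots.len() as u32).saturating_sub(prefix_len))
}

fn main() {
    let slots: Vec<u64> = (0..16).collect();
    // 12 prompt tokens, 4 new tokens, no speculation, 32-token budget,
    // 4 tokens already cached: admitted with a 12-slot suffix.
    assert_eq!(try_admit(12, 4, 0, 32, &slots, 4), Some(12));
    // The same entry against an 8-token budget is pushed back.
    assert_eq!(try_admit(12, 4, 0, 8, &slots, 4), None);
}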

View File

@@ -159,6 +159,7 @@ async fn prefill(
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
+           suffix_len: 0,
            adapter_id: None,
        })
        .collect();

View File

@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_handle(launcher):
+    with launcher(
+        "huggingface/llama-7b", num_shard=2, max_batch_prefill_tokens=2
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama(flash_llama_handle):
+    await flash_llama_handle.health(300)
+    return flash_llama_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama(flash_llama, response_snapshot):
+    response = await flash_llama.generate("What is Deep Learning ?", max_new_tokens=10)
+
+    assert response.details.generated_text == "xx"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_load(flash_llama, generate_load, response_snapshot):
+    responses = await generate_load(
+        flash_llama, "What is Deep Learning ?", max_new_tokens=10, n=4
+    )
+
+    assert responses[0].details.generated_text == "xx"
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert responses == response_snapshot

View File

@@ -1675,12 +1675,12 @@ fn main() -> Result<(), LauncherError> {
            "`max_input_tokens must be < `max_total_tokens`".to_string(),
        ));
    }
-   if max_input_tokens as u32 > max_batch_prefill_tokens {
-       return Err(LauncherError::ArgumentValidation(format!(
-           "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-           max_batch_prefill_tokens, max_input_tokens
-       )));
-   }
+   // if max_input_tokens as u32 > max_batch_prefill_tokens {
+   //     return Err(LauncherError::ArgumentValidation(format!(
+   //         "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
+   //         max_batch_prefill_tokens, max_input_tokens
+   //     )));
+   // }

    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
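
Note: both the router (earlier in this patch) and the launcher comment out the `max_batch_prefill_tokens >= max_input_tokens` requirement. That constraint existed because a prompt had to be prefilled in a single pass; once prefill can be chunked across scheduling steps, a prompt longer than the per-batch prefill budget becomes acceptable. A hedged sketch of the relation that remains enforced (simplified types, not the launcher's actual argument handling):

#[derive(Debug, PartialEq)]
enum ArgError {
    Validation(String),
}

/// Simplified validation: only the input/total relation is still enforced;
/// the prefill-budget comparison is gone because prefill can be chunked.
fn validate(max_input_tokens: usize, max_total_tokens: usize) -> Result<(), ArgError> {
    if max_input_tokens >= max_total_tokens {
        return Err(ArgError::Validation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    Ok(())
}

fn main() {
    // A 4096-token prompt with only a 512-token prefill budget would have
    // been rejected before; with chunking it is now considered valid.
    assert_eq!(validate(4096, 4097), Ok(()));
}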

View File

@@ -137,8 +137,10 @@ message Request {
    optional string adapter_id = 11;
    /// Prefix length that can be retrieved from the KV cache.
    uint32 prefix_len = 12;
+   /// Part of the query that needs to not be computed right away.
+   uint32 suffix_len = 13;
    /// Context truncation
-   bool add_special_tokens = 13;
+   bool add_special_tokens = 14;
 }

 message Batch {
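
Note on the proto change: with `prefix_len` and the new `suffix_len`, a request's tokens split into a cached prefix, the part computed in this prefill call, and a deferred suffix ("part of the query that needs to not be computed right away"); `add_special_tokens` moves to field 14 to make room. A small sketch of that partition, with illustrative names:

/// Tokens actually run through the model in this prefill call for a request
/// of `total_len` tokens, given the cached prefix and the deferred suffix.
fn computed_now(total_len: u32, prefix_len: u32, suffix_len: u32) -> u32 {
    debug_assert!(prefix_len + suffix_len <= total_len);
    total_len - prefix_len - suffix_len
}

fn main() {
    // 1000-token prompt, 200 tokens already in the KV cache, 500 deferred to
    // later chunks: this call computes the 300 tokens in between.
    assert_eq!(computed_now(1000, 200, 500), 300);
    // suffix_len = 0 reproduces the old, non-chunked behaviour.
    assert_eq!(computed_now(1000, 200, 0), 800);
}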

View File

@@ -269,6 +269,9 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)

             prefix_len = r.prefix_len
+
+            import ipdb
+            ipdb.set_trace()
             assert (
                 prefix_len <= orig_input_length
             ), f"Prefix {prefix_len} vs input {orig_input_length}"
@@ -318,11 +321,11 @@ class FlashCausalLMBatch(Batch):
             speculative_length = 0 if speculative_length is None else speculative_length

             # Tokens that need to be mapped to blocks.
-            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+            block_tokens = orig_input_length + max_new_tokens + speculative_length

             # Tokens that need to be mapped to slots. We don't need slots for the
             # cached prefix (if present).
-            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+            slot_tokens = input_length + max_new_tokens + speculative_length

             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
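
Note on the flash_causal_lm change: dropping the `- 1` reserves blocks and slots for every position up to and including the last generated (or speculative) token. A small sketch of how the new count maps to paged-attention blocks, assuming a fixed block size (function names are illustrative):

/// Tokens that must be backed by KV-cache blocks for one request.
/// Previously this subtracted 1; the chunking work also reserves room for
/// the final position.
fn block_tokens(orig_input_length: u32, max_new_tokens: u32, speculative_length: u32) -> u32 {
    orig_input_length + max_new_tokens + speculative_length
}

/// Fixed-size blocks needed to hold that many tokens (ceiling division).
fn blocks_needed(tokens: u32, block_size: u32) -> u32 {
    tokens.div_ceil(block_size)
}

fn main() {
    // 20-token prompt, up to 12 new tokens, no speculation, 16-token blocks:
    // 32 tokens fit exactly in 2 blocks; the old count (31) also needed 2.
    assert_eq!(blocks_needed(block_tokens(20, 12, 0), 16), 2);
    // 21-token prompt, 12 new tokens: 33 tokens now need 3 blocks, whereas
    // the old count (32) would have reserved only 2. The off-by-one matters
    // exactly at block boundaries.
    assert_eq!(blocks_needed(block_tokens(21, 12, 0), 16), 3);
}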