TMP chunking.

Nicolas Patry 2024-09-02 11:46:36 +02:00
parent 38fcafcf96
commit 2f0fde1055
11 changed files with 77 additions and 30 deletions


@@ -159,6 +159,7 @@ impl Client {
blocks: vec![],
slots: vec![],
prefix_len: 0,
suffix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,


@@ -246,6 +246,7 @@ impl Health for ShardedClient {
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
suffix_len: 0,
adapter_id: None,
};
let batch = Batch {


@@ -113,16 +113,23 @@ impl Client {
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut rest_tokens = max_prefill_tokens;
let mut requests = Vec::new();
let max_tokens_per_request = core::cmp::min(max_input_length, max_prefill_tokens);
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
while rest_tokens > 0 {
let curr_tokens = min(max_tokens_per_request, rest_tokens);
let truncate = min(max_input_length, rest_tokens);
let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
let suffix_len = 0;
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if rest_tokens == max_prefill_tokens {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
@@ -131,14 +138,6 @@ impl Client {
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
@@ -158,7 +157,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
prefix_len,
suffix_len,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@@ -182,7 +182,7 @@ impl Client {
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
rest_tokens -= curr_tokens;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
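
The loop above replaces the old single-pass warmup (n_tokens accumulating up to max_prefill_tokens) with a countdown over rest_tokens, capping each request at min(max_input_length, max_prefill_tokens) and still attaching the image chunk only to the first request. A minimal sketch of that accounting, outside the diff and with a made-up helper name:

    // Sketch only: split max_prefill_tokens across several warmup requests,
    // each at most min(max_input_length, max_prefill_tokens) tokens, mirroring
    // the rest_tokens / curr_tokens loop above.
    fn warmup_chunk_sizes(max_input_length: u32, max_prefill_tokens: u32) -> Vec<u32> {
        let max_tokens_per_request = core::cmp::min(max_input_length, max_prefill_tokens);
        let mut rest_tokens = max_prefill_tokens;
        let mut sizes = Vec::new();
        while rest_tokens > 0 {
            let curr_tokens = core::cmp::min(max_tokens_per_request, rest_tokens);
            sizes.push(curr_tokens);
            rest_tokens -= curr_tokens;
        }
        sizes
    }

    fn main() {
        // e.g. a 4096-token prefill budget with a 1024-token max input -> 4 warmup requests
        assert_eq!(warmup_chunk_sizes(1024, 4096), vec![1024, 1024, 1024, 1024]);
    }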


@@ -247,6 +247,7 @@ impl Health for ShardedClient {
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
suffix_len: 0,
adapter_id: None,
};
let batch = Batch {


@@ -131,9 +131,9 @@ async fn main() -> Result<(), RouterError> {
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
// if max_input_tokens as u32 > max_batch_prefill_tokens {
// return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
// }
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
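
The same max_batch_prefill_tokens >= max_input_tokens guard is commented out here and in the launcher further down: with prefill chunking, an input longer than the per-batch prefill budget is presumably prefilled over several scheduler steps instead of being rejected at startup. A rough illustration, outside the diff and with arbitrary numbers:

    // Sketch only: how many prefill chunks one request needs once inputs are
    // allowed to exceed the per-batch prefill budget.
    fn prefill_steps(max_input_tokens: u32, max_batch_prefill_tokens: u32) -> u32 {
        // ceiling division
        (max_input_tokens + max_batch_prefill_tokens - 1) / max_batch_prefill_tokens
    }

    fn main() {
        // Previously rejected at validation time (8192 > 4096); with chunking it takes two steps.
        assert_eq!(prefill_steps(8192, 4096), 2);
    }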


@@ -278,7 +278,7 @@ impl State {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
if total_tokens > token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -298,9 +298,7 @@ impl State {
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -392,6 +390,7 @@ impl State {
block_allocation.prefix_len,
),
};
let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
entry.block_allocation = block_allocation;
@@ -428,6 +427,7 @@ impl State {
blocks,
slots,
prefix_len,
suffix_len,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
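
Two changes land in the queue: the admission check no longer compares prefill_tokens against prefill_token_budget on its own (only the combined total against token_budget), and each scheduled request now carries suffix_len, derived as the slot count minus the cached prefix. A minimal sketch of the relaxed check, outside the diff and with made-up names:

    // Sketch only: the simplified admission test. The per-step prefill budget no
    // longer rejects an entry by itself; only the overall token budget does.
    fn fits(prefill_tokens: u32, decode_tokens: u32, speculate: u32, token_budget: u32) -> bool {
        prefill_tokens + decode_tokens + speculate <= token_budget
    }

    fn main() {
        // A long prompt is admitted as long as the total budget allows it, even if
        // the per-step prefill budget is much smaller than the prompt.
        assert!(fits(10_000, 1_000, 0, 16_000));
        assert!(!fits(10_000, 1_000, 0, 8_000));
    }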


@@ -159,6 +159,7 @@ async fn prefill(
blocks: vec![],
slots: vec![],
prefix_len: 0,
suffix_len: 0,
adapter_id: None,
})
.collect();


@@ -0,0 +1,38 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_handle(launcher):
    with launcher(
        "huggingface/llama-7b", num_shard=2, max_batch_prefill_tokens=2
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama(flash_llama_handle):
    await flash_llama_handle.health(300)
    return flash_llama_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama(flash_llama, response_snapshot):
    response = await flash_llama.generate("What is Deep Learning ?", max_new_tokens=10)

    assert response.details.generated_text == "xx"
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_load(flash_llama, generate_load, response_snapshot):
    responses = await generate_load(
        flash_llama, "What is Deep Learning ?", max_new_tokens=10, n=4
    )

    assert responses[0].details.generated_text == "xx"
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert responses == response_snapshot


@@ -1675,12 +1675,12 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
// if max_input_tokens as u32 > max_batch_prefill_tokens {
// return Err(LauncherError::ArgumentValidation(format!(
// "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
// max_batch_prefill_tokens, max_input_tokens
// )));
// }
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");


@@ -137,8 +137,10 @@ message Request {
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Part of the query that needs to not be computed right away.
uint32 suffix_len = 13;
/// Context truncation
bool add_special_tokens = 13;
bool add_special_tokens = 14;
}
message Batch {
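
On the wire, suffix_len sits next to prefix_len: prefix_len is what can be pulled from the KV cache, suffix_len the trailing part of the query whose compute is deferred, and add_special_tokens is renumbered from tag 13 to 14 (a wire-format change for that field). A rough illustration of how the spans would relate, outside the diff and only an interpretation of the comments, not a guaranteed invariant:

    // Sketch only: of a request's input tokens, prefix_len come from the KV cache,
    // suffix_len are deferred to a later chunk, the rest is computed in this step.
    fn computed_now(input_len: u32, prefix_len: u32, suffix_len: u32) -> u32 {
        input_len.saturating_sub(prefix_len).saturating_sub(suffix_len)
    }

    fn main() {
        // e.g. 600 input tokens, 512 already cached, 60 deferred -> 28 computed now.
        assert_eq!(computed_now(600, 512, 60), 28);
    }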


@@ -269,6 +269,9 @@ class FlashCausalLMBatch(Batch):
orig_input_length = len(tokenized_input)
prefix_len = r.prefix_len
import ipdb
ipdb.set_trace()
assert (
prefix_len <= orig_input_length
), f"Prefix {prefix_len} vs input {orig_input_length}"
@@ -318,11 +321,11 @@ class FlashCausalLMBatch(Batch):
speculative_length = 0 if speculative_length is None else speculative_length
# Tokens that need to be mapped to blocks.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
block_tokens = orig_input_length + max_new_tokens + speculative_length
# Tokens that need to be mapped to slots. We don't need slots for the
# cached prefix (if present).
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
slot_tokens = input_length + max_new_tokens + speculative_length
# blocks and slots can be empty (for example in warmup)
if not r.blocks:
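
The sizing above drops the old "- 1" from both block_tokens and slot_tokens, so one extra token per request is now reserved for blocks and slots. A small worked example, outside the diff:

    // Sketch only: the adjusted reservation without the previous "- 1".
    fn block_tokens(orig_input_length: u32, max_new_tokens: u32, speculative_length: u32) -> u32 {
        orig_input_length + max_new_tokens + speculative_length
    }

    fn main() {
        // 100 prompt tokens, 10 new tokens, no speculation: 110 reserved (109 before).
        assert_eq!(block_tokens(100, 10, 0), 110);
    }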