Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-22 15:32:08 +00:00

TMP chunking.

commit 2f0fde1055
parent 38fcafcf96
@@ -159,6 +159,7 @@ impl Client {
                 blocks: vec![],
                 slots: vec![],
                 prefix_len: 0,
+                suffix_len: 0,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -246,6 +246,7 @@ impl Health for ShardedClient {
             blocks: vec![0],
             slots: (0..16).collect(),
             prefix_len: 0,
+            suffix_len: 0,
             adapter_id: None,
         };
         let batch = Batch {
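
Note: the two health-check hunks above simply default the new field to zero. A minimal Rust sketch of the token-accounting fields this commit threads through each request; this is a trimmed, illustrative stand-in, not the real generated type, which has many more fields:

// Illustration only: a cut-down stand-in for the generated `Request` type,
// limited to the cache/token-accounting fields this commit touches.
#[derive(Debug, Default)]
struct RequestSketch {
    blocks: Vec<u32>,
    slots: Vec<u32>,
    prefix_len: u32, // tokens already available in the KV cache
    suffix_len: u32, // tokens deliberately left for a later prefill pass (new field)
}

fn main() {
    // Health-check requests keep both lengths at zero, as in the hunks above.
    let liveness = RequestSketch {
        blocks: vec![0],
        slots: (0..16).collect(),
        ..Default::default()
    };
    println!("{liveness:?}");
}
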
@@ -113,16 +113,23 @@ impl Client {
         max_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
-        let mut n_tokens = 0;
+        let mut rest_tokens = max_prefill_tokens;
         let mut requests = Vec::new();
 
+        let max_tokens_per_request = core::cmp::min(max_input_length, max_prefill_tokens);
         // Create requests
-        while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+        while rest_tokens > 0 {
+            let curr_tokens = min(max_tokens_per_request, rest_tokens);
+            let truncate = min(max_input_length, rest_tokens);
+            let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
+            let suffix_len = 0;
+
             let mut input_chunks = Vec::new();
             input_chunks
                 .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
-            if n_tokens == 0 {
+            let mut inputs = String::new();
+            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            if rest_tokens == max_prefill_tokens {
                 input_chunks.push(
                     Chunk::Image(Image {
                         // Safe unwrap, because we control the data.
@@ -131,14 +138,6 @@ impl Client {
                     })
                     .into(),
                 );
-            }
-
-            // Send stringly-typed inputs for compatibility for backends that haven't
-            // been updated to support chunks.
-
-            let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
-            if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
                 inputs.push_str(&format!(
@@ -158,7 +157,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
+                prefix_len,
+                suffix_len,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -182,7 +182,7 @@ impl Client {
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            rest_tokens -= curr_tokens;
 
             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
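
Note: taken together, the warmup hunks replace the old "one full-length request per iteration" loop with a countdown over the prefill budget. A minimal, self-contained sketch of that accounting, using the names from the diff (illustration only; the real code also builds input chunks and sampling parameters):

use std::cmp::min;

// Sketch of the reworked warmup accounting: split `max_prefill_tokens` into
// per-request chunks instead of assuming every request uses `max_input_length` tokens.
fn warmup_chunks(max_input_length: u32, max_prefill_tokens: u32) -> Vec<(u32, u32, u32)> {
    assert!(max_input_length > 0 && max_prefill_tokens > 0);
    let max_tokens_per_request = min(max_input_length, max_prefill_tokens);
    let mut rest_tokens = max_prefill_tokens;
    let mut chunks = Vec::new();
    while rest_tokens > 0 {
        let curr_tokens = min(max_tokens_per_request, rest_tokens);
        let truncate = min(max_input_length, rest_tokens);
        // Mirrors the diff: any excess of `max_input_length` over the prefill
        // budget is counted as an already-cached prefix; `suffix_len` stays 0 here.
        let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
        chunks.push((curr_tokens, truncate, prefix_len));
        rest_tokens -= curr_tokens;
    }
    chunks
}

fn main() {
    // e.g. a 4096-token prefill budget with 1024-token inputs -> 4 warmup requests.
    assert_eq!(warmup_chunks(1024, 4096).len(), 4);
}
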
@@ -247,6 +247,7 @@ impl Health for ShardedClient {
             blocks: vec![0],
             slots: (0..16).collect(),
             prefix_len: 0,
+            suffix_len: 0,
             adapter_id: None,
         };
         let batch = Batch {
@@ -131,9 +131,9 @@ async fn main() -> Result<(), RouterError> {
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
-    }
+    // if max_input_tokens as u32 > max_batch_prefill_tokens {
+    //     return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+    // }
 
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
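
Note: the commented-out router check enforced `max_batch_prefill_tokens >= max_input_tokens`; presumably it is relaxed because, with prefill chunking, an input longer than the per-batch prefill budget can be prefilled over several passes instead of being rejected up front (the new integration test below sets `max_batch_prefill_tokens=2`). A rough illustration of the resulting pass count, not code from the repository:

// Rough illustration: how many prefill passes a prompt would need once the
// `max_batch_prefill_tokens >= max_input_tokens` requirement is dropped.
fn prefill_passes(input_tokens: u32, max_batch_prefill_tokens: u32) -> u32 {
    assert!(max_batch_prefill_tokens > 0);
    input_tokens.div_ceil(max_batch_prefill_tokens)
}

fn main() {
    // A 10_000-token prompt under a 4_096-token prefill budget -> 3 passes.
    assert_eq!(prefill_passes(10_000, 4_096), 3);
}
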
@@ -278,7 +278,7 @@ impl State {
             decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             let total_tokens = prefill_tokens + decode_tokens + self.speculate;
 
-            if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+            if total_tokens > token_budget {
                 // Entry is over budget
                 // Add it back to the front
                 tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -298,9 +298,7 @@ impl State {
             };
             decode_tokens += max_new_tokens;
 
-            if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-            {
+            if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
                 // Entry is over budget
                 // Add it back to the front
                 tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
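
Note: both queue hunks drop the standalone `prefill_tokens > prefill_token_budget` clause, so an entry is now refused only on the total budget (the debug message still prints the old condition). A minimal sketch of the before/after predicate, using the names from the diff:

// Old check: reject when either the prefill budget or the total budget is exceeded.
fn over_budget_old(
    prefill: u32,
    decode: u32,
    speculate: u32,
    prefill_budget: u32,
    token_budget: u32,
) -> bool {
    prefill > prefill_budget || prefill + decode + speculate > token_budget
}

// New check in this commit: only the total token budget matters, presumably
// because an oversized prefill can now be chunked rather than rejected.
fn over_budget_new(prefill: u32, decode: u32, speculate: u32, token_budget: u32) -> bool {
    prefill + decode + speculate > token_budget
}

fn main() {
    // An entry whose prefill alone exceeds the prefill budget is no longer
    // refused, as long as it fits within the overall token budget.
    assert!(over_budget_old(3000, 100, 0, 2048, 8192));
    assert!(!over_budget_new(3000, 100, 0, 8192));
}
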
@@ -392,6 +390,7 @@ impl State {
                     block_allocation.prefix_len,
                 ),
             };
+            let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
 
             entry.block_allocation = block_allocation;
 
@@ -428,6 +427,7 @@ impl State {
                 blocks,
                 slots,
                 prefix_len,
+                suffix_len,
                 adapter_id: entry.request.adapter_id.clone(),
             });
             // Set batch_time
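
Note: when the queue turns an entry into a request, `suffix_len` is derived from the block allocation as the number of allocated slots beyond the cached prefix, exactly the `saturating_sub` in the hunk above. A small sketch of that arithmetic:

// Sketch of the `suffix_len` computation added in the queue: allocated slots
// not already covered by the cached prefix, clamped at zero.
fn suffix_len(slots: &[u32], prefix_len: u32) -> u32 {
    (slots.len() as u32).saturating_sub(prefix_len)
}

fn main() {
    let slots: Vec<u32> = (0..16).collect();
    assert_eq!(suffix_len(&slots, 4), 12); // 16 allocated slots, 4 already cached
    assert_eq!(suffix_len(&slots, 32), 0); // saturating: never underflows
}
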
@@ -159,6 +159,7 @@ async fn prefill(
             blocks: vec![],
             slots: vec![],
             prefix_len: 0,
+            suffix_len: 0,
             adapter_id: None,
         })
         .collect();
integration-tests/models/test_flash_llama_chunking.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_handle(launcher):
+    with launcher(
+        "huggingface/llama-7b", num_shard=2, max_batch_prefill_tokens=2
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama(flash_llama_handle):
+    await flash_llama_handle.health(300)
+    return flash_llama_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama(flash_llama, response_snapshot):
+    response = await flash_llama.generate("What is Deep Learning ?", max_new_tokens=10)
+
+    assert response.details.generated_text == "xx"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_load(flash_llama, generate_load, response_snapshot):
+    responses = await generate_load(
+        flash_llama, "What is Deep Learning ?", max_new_tokens=10, n=4
+    )
+    assert responses[0].details.generated_text == "xx"
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
@@ -1675,12 +1675,12 @@ fn main() -> Result<(), LauncherError> {
             "`max_input_tokens must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
+    // if max_input_tokens as u32 > max_batch_prefill_tokens {
+    //     return Err(LauncherError::ArgumentValidation(format!(
+    //         "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
+    //         max_batch_prefill_tokens, max_input_tokens
+    //     )));
+    // }
 
     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
         tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@@ -137,8 +137,10 @@ message Request {
     optional string adapter_id = 11;
     /// Prefix length that can be retrieved from the KV cache.
     uint32 prefix_len = 12;
+    /// Part of the query that needs to not be computed right away.
+    uint32 suffix_len = 13;
     /// Context truncation
-    bool add_special_tokens = 13;
+    bool add_special_tokens = 14;
 }
 
 message Batch {
@@ -269,6 +269,9 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
+            import ipdb
+
+            ipdb.set_trace()
             assert (
                 prefix_len <= orig_input_length
             ), f"Prefix {prefix_len} vs input {orig_input_length}"
@@ -318,11 +321,11 @@ class FlashCausalLMBatch(Batch):
             speculative_length = 0 if speculative_length is None else speculative_length
 
             # Tokens that need to be mapped to blocks.
-            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+            block_tokens = orig_input_length + max_new_tokens + speculative_length
 
             # Tokens that need to be mapped to slots. We don't need slots for the
             # cached prefix (if present).
-            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+            slot_tokens = input_length + max_new_tokens + speculative_length
 
             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
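
Note: the two FlashCausalLMBatch hunks drop the `- 1` from both token counts, so one extra token's worth of blocks and slots is reserved per request (the `import ipdb` / `ipdb.set_trace()` above is a leftover debugging aid, consistent with the "TMP" commit message). A small Rust rendition of the changed arithmetic, mirroring the Python:

// Rust rendition of the Python arithmetic changed above, with the `- 1` removed.
fn token_counts(
    orig_input_length: u32,
    input_length: u32,
    max_new_tokens: u32,
    speculative_length: u32,
) -> (u32, u32) {
    // Tokens that need to be mapped to blocks / to slots.
    let block_tokens = orig_input_length + max_new_tokens + speculative_length;
    let slot_tokens = input_length + max_new_tokens + speculative_length;
    (block_tokens, slot_tokens)
}

fn main() {
    // Before this commit these would have been (1041, 529); now (1042, 530).
    assert_eq!(token_counts(1024, 512, 16, 2), (1042, 530));
}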