mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
TMP chunking.
This commit is contained in:
parent
38fcafcf96
commit
2f0fde1055
@ -159,6 +159,7 @@ impl Client {
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
suffix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
@ -246,6 +246,7 @@ impl Health for ShardedClient {
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
suffix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
@ -113,16 +113,23 @@ impl Client {
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut rest_tokens = max_prefill_tokens;
|
||||
let mut requests = Vec::new();
|
||||
|
||||
let max_tokens_per_request = core::cmp::min(max_input_length, max_prefill_tokens);
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
while rest_tokens > 0 {
|
||||
let curr_tokens = min(max_tokens_per_request, rest_tokens);
|
||||
let truncate = min(max_input_length, rest_tokens);
|
||||
let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
|
||||
let suffix_len = 0;
|
||||
|
||||
let mut input_chunks = Vec::new();
|
||||
input_chunks
|
||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||
if n_tokens == 0 {
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if rest_tokens == max_prefill_tokens {
|
||||
input_chunks.push(
|
||||
Chunk::Image(Image {
|
||||
// Safe unwrap, because we control the data.
|
||||
@ -131,14 +138,6 @@ impl Client {
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||
// been updated to support chunks.
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str(&format!(
|
||||
@ -158,7 +157,8 @@ impl Client {
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
prefix_len,
|
||||
suffix_len,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
@ -182,7 +182,7 @@ impl Client {
|
||||
top_n_tokens: 20,
|
||||
adapter_id: None,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
rest_tokens -= curr_tokens;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
|
@ -247,6 +247,7 @@ impl Health for ShardedClient {
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
suffix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
@ -131,9 +131,9 @@ async fn main() -> Result<(), RouterError> {
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
// if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
// return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
// }
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
|
@ -278,7 +278,7 @@ impl State {
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||
|
||||
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
|
||||
if total_tokens > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
@ -298,9 +298,7 @@ impl State {
|
||||
};
|
||||
decode_tokens += max_new_tokens;
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||
{
|
||||
if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
@ -392,6 +390,7 @@ impl State {
|
||||
block_allocation.prefix_len,
|
||||
),
|
||||
};
|
||||
let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
|
||||
|
||||
entry.block_allocation = block_allocation;
|
||||
|
||||
@ -428,6 +427,7 @@ impl State {
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len,
|
||||
suffix_len,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
});
|
||||
// Set batch_time
|
||||
|
@ -159,6 +159,7 @@ async fn prefill(
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
suffix_len: 0,
|
||||
adapter_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
38
integration-tests/models/test_flash_llama_chunking.py
Normal file
38
integration-tests/models/test_flash_llama_chunking.py
Normal file
@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_handle(launcher):
|
||||
with launcher(
|
||||
"huggingface/llama-7b", num_shard=2, max_batch_prefill_tokens=2
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama(flash_llama_handle):
|
||||
await flash_llama_handle.health(300)
|
||||
return flash_llama_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama(flash_llama, response_snapshot):
|
||||
response = await flash_llama.generate("What is Deep Learning ?", max_new_tokens=10)
|
||||
|
||||
assert response.details.generated_text == "xx"
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_load(flash_llama, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_llama, "What is Deep Learning ?", max_new_tokens=10, n=4
|
||||
)
|
||||
assert responses[0].details.generated_text == "xx"
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
@ -1675,12 +1675,12 @@ fn main() -> Result<(), LauncherError> {
|
||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_input_tokens
|
||||
)));
|
||||
}
|
||||
// if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
// return Err(LauncherError::ArgumentValidation(format!(
|
||||
// "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
||||
// max_batch_prefill_tokens, max_input_tokens
|
||||
// )));
|
||||
// }
|
||||
|
||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||
|
@ -137,8 +137,10 @@ message Request {
|
||||
optional string adapter_id = 11;
|
||||
/// Prefix length that can be retrieved from the KV cache.
|
||||
uint32 prefix_len = 12;
|
||||
/// Part of the query that needs to not be computed right away.
|
||||
uint32 suffix_len = 13;
|
||||
/// Context truncation
|
||||
bool add_special_tokens = 13;
|
||||
bool add_special_tokens = 14;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
@ -269,6 +269,9 @@ class FlashCausalLMBatch(Batch):
|
||||
orig_input_length = len(tokenized_input)
|
||||
|
||||
prefix_len = r.prefix_len
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
assert (
|
||||
prefix_len <= orig_input_length
|
||||
), f"Prefix {prefix_len} vs input {orig_input_length}"
|
||||
@ -318,11 +321,11 @@ class FlashCausalLMBatch(Batch):
|
||||
speculative_length = 0 if speculative_length is None else speculative_length
|
||||
|
||||
# Tokens that need to be mapped to blocks.
|
||||
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
|
||||
block_tokens = orig_input_length + max_new_tokens + speculative_length
|
||||
|
||||
# Tokens that need to be mapped to slots. We don't need slots for the
|
||||
# cached prefix (if present).
|
||||
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||
slot_tokens = input_length + max_new_tokens + speculative_length
|
||||
|
||||
# blocks and slots can be empty (for example in warmup)
|
||||
if not r.blocks:
|
||||
|
Loading…
Reference in New Issue
Block a user