huggingface/text-generation-inference
(mirror of https://github.com/huggingface/text-generation-inference.git)

commit 26e5037de4 (parent f5182c188c)

    This seems to be working.
@@ -158,15 +158,17 @@ impl Allocator for RadixAllocator {
         if let Some(prefill_tokens) = allocation.prefill_tokens {
             let prefill_tokens = prefill_tokens.as_slice();
 
-            assert_eq!(prefill_tokens.len() % self.block_size as usize, 0);
             // If there are prefill tokens that did not come from the cache,
             // add them to the cache.
             if prefill_tokens.len() > allocation.cached_prefix_len {
+                let aligned =
+                    (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
+                if aligned > 0 {
                 let prefix_len = self
                     .cache_blocks
                     .insert(
-                        prefill_tokens,
-                        &blocks[..prefill_tokens.len() / self.block_size as usize],
+                        &prefill_tokens[..aligned],
+                        &blocks[..aligned / self.block_size as usize],
                     )
                     // Unwrap, failing is a programming error.
                     .expect("Failed to store prefill tokens");
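The substance of this hunk is the alignment step: with a block size larger than one, a prefill can end in the middle of a block, and only whole blocks can be inserted into the radix trie, so the prefill length is first rounded down to a block boundary. A minimal standalone sketch of that arithmetic (the function name is illustrative, not the allocator's API), using the 7-token, block-size-2 case exercised by the new test further down:

/// Illustrative sketch of the alignment step added above: round a prefill
/// length down to a whole number of blocks before caching it.
fn aligned_prefill(prefill_len: usize, block_size: usize) -> (usize, usize) {
    // Number of *complete* blocks covered by the prefill.
    let full_blocks = prefill_len / block_size;
    // Token count rounded down to a block boundary; the remainder is never cached.
    let aligned = full_blocks * block_size;
    (aligned, full_blocks)
}

fn main() {
    // Mirrors the non-aligned case: 7 prefill tokens, block size 2.
    let (aligned, full_blocks) = aligned_prefill(7, 2);
    assert_eq!(aligned, 6); // only tokens 0..6 are eligible for the cache
    assert_eq!(full_blocks, 3); // i.e. &blocks[..3] is what gets inserted
    println!("aligned = {aligned}, blocks cached = {full_blocks}");
}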
@@ -180,12 +182,16 @@ impl Allocator for RadixAllocator {
                 // This means that while processing this request there was a
                 // partially overlapping request that had A..=E in its
                 // prefill. In this case we need to free the blocks D E.
-                self.free_blocks
-                    .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
+                self.free_blocks.extend(
+                    &blocks[allocation.cached_prefix_len / self.block_size as usize
+                        ..prefix_len / self.block_size as usize],
+                );
+                }
             }
 
             // Free non-prefill blocks.
-            self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
+            self.free_blocks
+                .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
         } else {
             self.free_blocks.extend(blocks);
         }
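This hunk applies the matching fix on the free path: `blocks` is indexed in blocks, while `cached_prefix_len`, `prefix_len`, and `prefill_tokens.len()` are token counts, so every slice bound is divided by the block size before slicing. A small standalone sketch of that conversion (the helper name is illustrative, not the allocator's API):

/// Illustrative sketch: token offsets must be converted to block offsets
/// before slicing the allocation's block list.
fn uncached_blocks(blocks: &[u32], prefill_token_count: usize, block_size: usize) -> &[u32] {
    // Integer division rounds down, so this is exactly the set of blocks the
    // aligned trie insert above did not keep: the partially filled trailing
    // prefill block (if any) plus all purely generated blocks.
    &blocks[prefill_token_count / block_size..]
}

fn main() {
    // 4 blocks of size 2 backing a 7-token prefill: blocks 8, 9, 10 hold the
    // 6 aligned tokens and stay cached; block 11 goes back to the free list.
    assert_eq!(uncached_blocks(&[8, 9, 10, 11], 7, 2).to_vec(), vec![11]);
}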
@@ -605,6 +611,24 @@ mod tests {
         assert_eq!(allocation.prefix_len, 4);
     }
 
+    #[test]
+    fn allocator_block_size_non_aligned() {
+        let mut cache = RadixAllocator::new(2, 12, None);
+        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
+        assert_eq!(allocation.prefix_len, 0);
+        cache.free(
+            allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
+            allocation.allocation_id,
+        );
+
+        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
+        assert_eq!(allocation.prefix_len, 4);
+    }
+
     #[test]
     fn allocator_reuses_prefixes() {
         let mut cache = RadixAllocator::new(1, 12, None);
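The expected values in the new test follow from the block and slot layout: with 12 blocks of size 2, a 7-slot request needs ceil(7 / 2) = 4 blocks, and the asserted slot numbers are consistent with block b covering slots b * 2 and b * 2 + 1. A quick standalone sketch of that arithmetic (not the allocator's code):

/// Illustrative sketch of the block/slot bookkeeping the assertions rely on.
fn slots_for(blocks: &[u32], block_size: u32, requested: usize) -> Vec<u32> {
    blocks
        .iter()
        .flat_map(|b| (b * block_size..(b + 1) * block_size))
        .take(requested)
        .collect()
}

fn main() {
    // A 7-slot request with block size 2 needs ceil(7 / 2) = 4 blocks.
    let needed = (7 + 2 - 1) / 2;
    assert_eq!(needed, 4);
    // Blocks 8..=11 expose slots 16..=22 (the second slot of block 11 is unused).
    assert_eq!(slots_for(&[8, 9, 10, 11], 2, 7), vec![16, 17, 18, 19, 20, 21, 22]);
}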
@@ -83,7 +83,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                         "Forcing flash decoding because model {} requires it",
                         config.model_type.as_ref().unwrap()
                     );
-                    prefix_caching = Some("0".to_string());
+                    prefix_caching = Some("1".to_string());
                     attention = Some("flashdecoding".to_string());
                 }
             }
@@ -93,7 +93,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             _ => {
                 if prefix_caching.is_none() {
                     tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    prefix_caching = Some("0".to_string());
+                    prefix_caching = Some("1".to_string());
                     attention = Some("flashdecoding".to_string());
                 }
             }
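Both resolve_attention hunks make the same one-character change: when flashdecoding is forced because of the model type or an unsupported head dim, and the user expressed no preference, prefix caching now defaults to "1" (enabled) rather than "0" (disabled). A simplified sketch of that pattern, with head_dim_supported standing in for the real checks (this is not the launcher's actual code):

// Simplified sketch of the default-resolution pattern both hunks above touch.
fn resolve(mut prefix_caching: Option<String>, head_dim_supported: bool) -> (Option<String>, Option<String>) {
    let mut attention = None;
    if !head_dim_supported && prefix_caching.is_none() {
        // Before this commit the forced-flashdecoding path also disabled
        // prefix caching ("0"); it now leaves it enabled ("1").
        prefix_caching = Some("1".to_string());
        attention = Some("flashdecoding".to_string());
    }
    (prefix_caching, attention)
}

fn main() {
    let (prefix_caching, attention) = resolve(None, false);
    assert_eq!(prefix_caching.as_deref(), Some("1"));
    assert_eq!(attention.as_deref(), Some("flashdecoding"));
}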
@@ -1000,7 +1000,7 @@ impl TryFrom<&[u8]> for PythonLogMessage {
 }
 
 fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
-    let mut buffer = vec![0u8; 4096];
+    let mut buffer = vec![0u8; 8 * 4096];
     let mut stdout = std::io::stdout();
     loop {
         let n = bufread.read(&mut buffer);
@@ -265,13 +265,10 @@ class FlashCausalLMBatch(Batch):
 
             orig_input_length = len(tokenized_input)
 
-            if ATTENTION == "flashinfer":
             prefix_len = r.prefix_len
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
-            else:
-                prefix_len = 0
 
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
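With the flashinfer gate removed, the requested prefix length is honored for every attention backend, and only the clamp remains: if the cached prefix covers the whole input, back off by one token so the model still has something to process. A small sketch of that rule (written in Rust for consistency with the sketches above; the name is illustrative):

// Illustrative sketch of the clamping rule kept by the hunk above.
fn usable_prefix_len(prefix_len: usize, input_length: usize) -> usize {
    if prefix_len == input_length {
        // A fully cached input still needs one live token for the forward pass.
        assert!(prefix_len > 0);
        prefix_len - 1
    } else {
        prefix_len
    }
}

fn main() {
    assert_eq!(usable_prefix_len(4, 7), 4); // partial prefix is used as-is
    assert_eq!(usable_prefix_len(7, 7), 6); // fully cached input keeps one token live
}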
@@ -14,8 +14,8 @@ assert (
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")
 
-if PREFIX_CACHING and ATTENTION != "flashinfer":
-    raise RuntimeError("Prefix caching is only supported with flashinfer")
+# if PREFIX_CACHING and ATTENTION != "flashinfer":
+#     raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 