Router logic knows about page size.

Missing 2 models.
This commit is contained in:
Nicolas Patry 2024-05-29 15:37:46 +00:00
parent 7a29e82629
commit 50d5c08b15
3 changed files with 9 additions and 1 deletions

View File

@ -71,7 +71,13 @@ impl Infer {
processor_config: HubProcessorConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 32 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});

View File

@ -240,6 +240,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)

View File

@ -213,6 +213,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)