Router logic knows about page size.

Missing 2 models.
This commit is contained in:
Nicolas Patry 2024-05-29 15:37:46 +00:00
parent 7a29e82629
commit 50d5c08b15
3 changed files with 9 additions and 1 deletion

View File

@ -71,7 +71,13 @@ impl Infer {
processor_config: HubProcessorConfig, processor_config: HubProcessorConfig,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate); let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 32 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });

View File

@ -240,6 +240,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
None,
input_lengths, input_lengths,
max_s, max_s,
) )

View File

@ -213,6 +213,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
None,
input_lengths, input_lengths,
max_s, max_s,
) )