From 50d5c08b15e6bd5ffcd73ef1233abbbb1ce04437 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 29 May 2024 15:37:46 +0000
Subject: [PATCH] Router logic knows about page size. Missing 2 models.

---
 router/src/infer.rs                                        | 8 +++++++-
 .../models/custom_modeling/flash_gpt2_modeling.py          | 1 +
 .../models/custom_modeling/flash_phi_modeling.py           | 1 +
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 1447e756..cc28e3af 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -71,7 +71,13 @@ impl Infer {
         processor_config: HubProcessorConfig,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 32 };
+        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index d2599f7a..85bfafd2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -240,6 +240,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index f2efb538..455f5771 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -213,6 +213,7 @@ class FlashPhiAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
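
Notes (reviewer commentary, not part of the patch):

The infer.rs hunk replaces the hard-coded page size of 16 with one the
router derives itself: 32 blocks by default, 256 when FLASH_DECODING is
set to "1" or "true" (case-insensitive). Below is a minimal,
self-contained Rust sketch of that toggle for local experimentation;
the helper name block_size is ours, not the router's, and the function
takes the variable's value as a parameter so the truth table can be
checked without mutating the process environment.

    use std::env;

    /// Hypothetical helper mirroring the hunk above: FLASH_DECODING set to
    /// "1" or "true" (case-insensitive) selects the 256-token flash-decoding
    /// page size; anything else keeps the 32-token paged-attention default.
    fn block_size(flash_decoding_var: Option<&str>) -> u32 {
        let flashdecoding = flash_decoding_var
            .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true"))
            .unwrap_or(false);
        if flashdecoding { 256 } else { 32 }
    }

    fn main() {
        // Same truth table as the router change.
        assert_eq!(block_size(None), 32);          // unset -> paged default
        assert_eq!(block_size(Some("TRUE")), 256); // case-insensitive
        assert_eq!(block_size(Some("1")), 256);
        assert_eq!(block_size(Some("0")), 32);     // anything else -> 32
        // Reading the real variable would look like:
        let v = env::var("FLASH_DECODING").ok();
        println!("block_size = {}", block_size(v.as_deref()));
    }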
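The two Python hunks are mechanical: each passes None for a new positional
argument of the attention call, slotted between block_tables and
input_lengths. GPT-2 and Phi appear to be the "2 models" from the subject
line, still calling the old arity after the rest of this series changed
the signature; once a positional parameter is added mid-list, every caller
has to be updated, or the later arguments bind one slot off (or the call
fails outright with a TypeError).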