From 630e417ca0a5df5ad7bf19150ccebb89a06167f9 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 27 Sep 2023 12:20:20 +0200
Subject: [PATCH] add window size in proto

---
 proto/generate.proto                          |  1 +
 router/src/infer.rs                           |  3 +-
 router/src/queue.rs                           | 31 ++++++++++++++-----
 router/src/server.rs                          |  1 +
 server/Makefile-vllm                          |  2 +-
 .../models/flash_causal_lm.py                 |  8 ++---
 .../models/flash_mistral.py                   |  2 +-
 server/text_generation_server/models/model.py |  6 ++++
 8 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/proto/generate.proto b/proto/generate.proto
index 3f607dc5..c873e661 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -31,6 +31,7 @@ message InfoResponse {
     bool requires_padding = 1;
     string dtype = 2;
     string device_type = 3;
+    optional uint32 window_size = 4;
 }
 
 /// Empty request
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 67b5bde2..787ccfcf 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -50,10 +50,11 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16);
+        let queue = Queue::new(requires_padding, 16, window_size);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
diff --git a/router/src/queue.rs b/router/src/queue.rs
index e97a168e..be253d69 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -2,6 +2,7 @@ use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
+use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
@@ -33,12 +34,17 @@ pub(crate) struct Queue {
 }
 
 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
+    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();
 
         // Launch background queue task
-        tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
+        tokio::spawn(queue_task(
+            requires_padding,
+            block_size,
+            window_size,
+            queue_receiver,
+        ));
 
         Self { queue_sender }
     }
@@ -84,9 +90,10 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     block_size: u32,
+    window_size: Option<u32>,
     receiver: flume::Receiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size);
+    let mut state = State::new(requires_padding, block_size, window_size);
 
     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
@@ -126,16 +133,20 @@ struct State {
 
     /// Paged Attention block size
     block_size: u32,
+
+    /// Sliding window
+    window_size: Option<u32>,
 }
 
 impl State {
-    fn new(requires_padding: bool, block_size: u32) -> Self {
+    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
             block_size,
+            window_size,
         }
     }
 
@@ -204,11 +215,17 @@ impl State {
             if self.requires_padding {
                 decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             } else {
+                let max_new_tokens = match self.window_size {
+                    None => entry.request.stopping_parameters.max_new_tokens,
+                    Some(window_size) => min(
+                        window_size.saturating_sub(entry.request.input_length),
+                        entry.request.stopping_parameters.max_new_tokens,
+                    ),
+                };
+
                 // pad to block size
                 decode_tokens +=
-                    ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
-                        / self.block_size)
-                        * self.block_size;
+                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
             }
 
             if prefill_tokens > prefill_token_budget
diff --git a/router/src/server.rs b/router/src/server.rs
index fbc444fc..f254afd8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -595,6 +595,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        shard_info.window_size,
         generation_health,
     );
 
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index 96bfc108..2e965da0 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,4 +1,4 @@
-vllm_commit := e86af624d059969b0fb07b075b1d338bf10c3365
+vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
 
 vllm:
 	# Clone vllm
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cefa32d8..1fe40c0c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -636,12 +636,11 @@ class FlashCausalLM(Model):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
-        repeat_slots: bool = False,
+        sliding_window: Optional[int] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
-        self.repeat_slots = repeat_slots
 
         super(FlashCausalLM, self).__init__(
             model=model,
@@ -651,6 +650,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
+            sliding_window=sliding_window,
         )
 
     @property
@@ -665,7 +665,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.repeat_slots,
+            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
@@ -705,7 +705,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.repeat_slots,
+            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 266ae8dd..919e4625 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -331,7 +331,7 @@ class FlashMistral(FlashCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
-            repeat_slots=True,
+            sliding_window=config.sliding_window,
         )
 
     @property
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index f6e66d30..17d2ea9b 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -21,6 +21,7 @@ class Model(ABC):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
+        sliding_window: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -30,6 +31,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
+        self.sliding_window = sliding_window
 
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -40,10 +42,14 @@ class Model(ABC):
 
     @property
     def info(self) -> InfoResponse:
+        if self.requires_padding and self.sliding_window is not None:
+            raise NotImplementedError("sliding_window is not implemented with padding")
+
         return InfoResponse(
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
+            window_size=self.sliding_window,
         )
 
     @property
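
Note on the router change: the new budget logic in router/src/queue.rs caps each request's reserved decode tokens at window_size - input_length (saturating at zero) before rounding up to the PagedAttention block size. The following standalone Rust sketch mirrors that arithmetic; the helper name padded_decode_tokens and the concrete numbers are illustrative only and not part of the patch.

use std::cmp::min;

/// Sketch of the decode-token reservation: with a sliding window, a request
/// can never hold more than `window_size` KV slots, so the budget is capped
/// before padding up to the block size. Hypothetical helper, not from the patch.
fn padded_decode_tokens(
    window_size: Option<u32>,
    input_length: u32,
    max_new_tokens: u32,
    block_size: u32,
) -> u32 {
    let max_new_tokens = match window_size {
        None => max_new_tokens,
        Some(window_size) => min(window_size.saturating_sub(input_length), max_new_tokens),
    };
    // pad to block size, as in queue.rs
    ((max_new_tokens + block_size - 1) / block_size) * block_size
}

fn main() {
    // No sliding window: reserve all requested new tokens, rounded up to blocks.
    assert_eq!(padded_decode_tokens(None, 3000, 2048, 16), 2048);
    // Window of 4096 (Mistral-style): only 4096 - 3000 = 1096 slots can be live,
    // rounded up to 69 blocks of 16 = 1104 tokens.
    assert_eq!(padded_decode_tokens(Some(4096), 3000, 2048, 16), 1104);
    // Prompt already longer than the window: saturating_sub keeps the budget at 0.
    assert_eq!(padded_decode_tokens(Some(4096), 5000, 2048, 16), 0);
}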