From 09839b05f4c3375d678ca70dcdac7e77ad1371dc Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 5 Dec 2023 16:38:46 +0000
Subject: [PATCH] Fixing some simple stuff, adding `speculate` to budget.

---
 proto/generate.proto                          |  2 +-
 router/src/infer.rs                           |  3 +-
 router/src/queue.rs                           | 53 ++++++++++++++-----
 router/src/server.rs                          |  1 +
 .../models/flash_causal_lm.py                 | 16 ++----
 server/text_generation_server/models/model.py |  7 +++
 6 files changed, 56 insertions(+), 26 deletions(-)

diff --git a/proto/generate.proto b/proto/generate.proto
index 659c62ff..19ec059b 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -32,7 +32,7 @@ message InfoResponse {
     string dtype = 2;
     string device_type = 3;
     optional uint32 window_size = 4;
-    optional uint32 speculate = 5;
+    uint32 speculate = 5;
 }

 /// Empty request
diff --git a/router/src/infer.rs b/router/src/infer.rs
index dc5bbb01..de8debc3 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -50,10 +50,11 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
+        speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16, window_size);
+        let queue = Queue::new(requires_padding, 16, window_size, speculate);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
diff --git a/router/src/queue.rs b/router/src/queue.rs
index bbb8db0e..0436b8f2 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -34,7 +34,12 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    pub(crate) fn new(
+        requires_padding: bool,
+        block_size: u32,
+        window_size: Option<u32>,
+        speculate: u32,
+    ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

@@ -43,6 +48,7 @@ impl Queue {
             requires_padding,
             block_size,
             window_size,
+            speculate,
             queue_receiver,
         ));

@@ -91,9 +97,10 @@ async fn queue_task(
     requires_padding: bool,
     block_size: u32,
     window_size: Option<u32>,
+    speculate: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size);
+    let mut state = State::new(requires_padding, block_size, window_size, speculate);

     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -136,10 +143,18 @@ struct State {

     /// Sliding window
     window_size: Option<u32>,
+
+    /// Speculation amount
+    speculate: u32,
 }

 impl State {
-    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    fn new(
+        requires_padding: bool,
+        block_size: u32,
+        window_size: Option<u32>,
+        speculate: u32,
+    ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
@@ -147,6 +162,7 @@ impl State {
             requires_padding,
             block_size,
             window_size,
+            speculate,
         }
     }

@@ -221,7 +237,7 @@ impl State {
                     window_size.saturating_sub(entry.request.input_length),
                     entry.request.stopping_parameters.max_new_tokens,
                 ),
-            };
+            } + self.speculate;

             // pad to block size
             decode_tokens +=
@@ -359,7 +375,7 @@ mod tests {

     #[test]
     fn test_append() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
@@ -375,7 +391,7 @@ mod tests {

     #[test]
     fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);

         assert!(state.next_batch(None, 1, 1).is_none());
         assert!(state.next_batch(Some(1), 1, 1).is_none());
@@ -383,7 +399,7 @@ mod tests {

     #[test]
     fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -415,7 +431,7 @@ mod tests {

     #[test]
     fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -448,14 +464,14 @@ mod tests {

     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);

         assert!(queue.next_batch(None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@@ -463,7 +479,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -496,7 +512,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -519,9 +535,20 @@ mod tests {
         assert_eq!(batch.size, 2);
     }

+    #[tokio::test]
+    async fn test_queue_next_batch_token_speculate() {
+        let queue = Queue::new(false, 1, None, 2);
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);
+
+        assert!(queue.next_batch(None, 1, 1).await.is_none());
+    }
+
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry, _) = default_entry();
         queue.append(entry);

diff --git a/router/src/server.rs b/router/src/server.rs
index f254afd8..5f41fd5e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -596,6 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
+        shard_info.speculate,
         generation_health,
     );

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 855061e5..2ca86488 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -938,8 +938,6 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
-            # next_token_ids,
-            # next_token_logprobs,
             accepted_ids,
             batch_top_token_ids,
             batch_top_token_logprobs,
@@ -957,8 +955,6 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            # next_token_id,
-            # next_token_logprob,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
@@ -968,21 +964,18 @@ class FlashCausalLM(Model):
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]

             next_token_texts = []
+            left = 0
             for j in range(index, index + n_accepted_ids):
                 # Generated token
-                all_input_ids.append(next_token_ids[j])
+                next_token_id = next_token_ids[j]
+                all_input_ids.append(next_token_id)
                 next_token_text, prefix_offset, read_offset = self.decode_token(
                     all_input_ids,
                     prefix_offset,
                     read_offset,
                 )
                 next_token_texts.append(next_token_text)
-            index += n_accepted_ids

-            # Evaluate stopping criteria
-
-            left = 0
-            for j, next_token_id in enumerate(_next_token_ids):
                 stop, reason = stopping_criteria(
                     next_token_id,
                     next_token_text,
@@ -994,6 +987,7 @@ class FlashCausalLM(Model):
                     break
                 else:
                     stopped = False
+            index += n_accepted_ids

             _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
             # Shard generations
@@ -1003,7 +997,7 @@ class FlashCausalLM(Model):
                 # Decode generated tokens
                 # Remove potentially accepted ids that do not respect
                 # the stopping_criteria
-                _ids = all_input_ids[:len(all_input_ids)-left]
+                _ids = all_input_ids
                 output_text, _, _ = self.decode_token(
                     _ids,
                     prefix_offset=len(_ids)
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 17d2ea9b..8552960d 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        speculate: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -33,6 +35,10 @@ class Model(ABC):
         self.world_size = world_size
         self.sliding_window = sliding_window

+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
@@ -50,6 +56,7 @@ class Model(ABC):
             dtype=str(self.dtype),
             device_type=self.device.type,
             window_size=self.sliding_window,
+            speculate=self.speculate
         )

     @property
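
Note on the budget change above: the router/src/queue.rs hunk adds `self.speculate` to the per-request decode-token estimate, so a request now also reserves room for the tokens it may verify speculatively each step. The following standalone Rust sketch mirrors that computation under stated assumptions; `BudgetParams` and `decode_token_budget` are illustrative names, not part of the TGI codebase, and the padding rule is assumed to be the usual round-up to the KV-cache block size.

/// Illustrative parameters for a single queued request (hypothetical type).
struct BudgetParams {
    /// Prompt length already consumed by the request.
    input_length: u32,
    /// Maximum number of new tokens requested by the client.
    max_new_tokens: u32,
    /// Sliding-window size of the model, if any.
    window_size: Option<u32>,
    /// Extra tokens verified per decode step; 0 disables speculation.
    speculate: u32,
    /// KV-cache block size used for padding.
    block_size: u32,
}

/// Decode tokens this request reserves in the batch token budget.
fn decode_token_budget(p: &BudgetParams) -> u32 {
    let max_new_tokens = match p.window_size {
        None => p.max_new_tokens,
        Some(window_size) => std::cmp::min(
            window_size.saturating_sub(p.input_length),
            p.max_new_tokens,
        ),
    } + p.speculate; // speculative tokens also consume budget, as in the patch

    // Pad up to the next multiple of the block size.
    ((max_new_tokens + p.block_size - 1) / p.block_size) * p.block_size
}

fn main() {
    // Mirrors the spirit of test_queue_next_batch_token_speculate
    // (assuming a request asking for a single new token): speculate = 2, block_size = 1.
    let p = BudgetParams {
        input_length: 0,
        max_new_tokens: 1,
        window_size: None,
        speculate: 2,
        block_size: 1,
    };
    // One requested token now reserves 3 decode tokens, so a token budget
    // of 1 can no longer admit the request and next_batch returns None.
    assert_eq!(decode_token_budget(&p), 3);
    println!("decode token budget = {}", decode_token_budget(&p));
}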