Fixing some simple stuff, adding speculate to budget.

Nicolas Patry 2023-12-05 16:38:46 +00:00
parent 5aa3a01971
commit 09839b05f4
6 changed files with 56 additions and 26 deletions

View File

@@ -32,7 +32,7 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
optional uint32 speculate = 5;
uint32 speculate = 5;
}
/// Empty request
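
Dropping optional makes speculate a plain proto3 scalar: the field always carries a value and reads as 0 when the shard does not set it, which matches treating "no speculation" as zero. A hedged illustration of that behavior with the generated Python bindings (standard proto3 semantics; only fields visible in this diff are used):

from text_generation_server.pb.generate_pb2 import InfoResponse

info = InfoResponse(dtype="float16", device_type="cuda")
# A non-optional uint32 always has a value; unset simply means the zero default.
assert info.speculate == 0

info.speculate = 2  # e.g. the shard speculates 2 extra tokens per decoding step
assert info.speculate == 2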

View File

@@ -50,10 +50,11 @@ impl Infer {
max_concurrent_requests: usize,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size);
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});

View File

@@ -34,7 +34,12 @@ pub(crate) struct Queue {
}
impl Queue {
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -43,6 +48,7 @@ impl Queue {
requires_padding,
block_size,
window_size,
speculate,
queue_receiver,
));
@@ -91,9 +97,10 @@ async fn queue_task(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size, window_size);
let mut state = State::new(requires_padding, block_size, window_size, speculate);
while let Some(cmd) = receiver.recv().await {
match cmd {
@@ -136,10 +143,18 @@ struct State {
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
}
impl State {
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
@@ -147,6 +162,7 @@ impl State {
requires_padding,
block_size,
window_size,
speculate,
}
}
@@ -221,7 +237,7 @@ impl State {
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
} + self.speculate;
// pad to block size
decode_tokens +=
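
To make the effect of "+ self.speculate" concrete: each queued entry's decode-token reservation now includes the speculation amount before the block-size padding at the end of this hunk. A rough Python sketch of that arithmetic (illustrative names only; the real logic is the Rust shown here):

def reserved_decode_tokens(max_new_tokens, speculate, block_size):
    # Speculation is now counted against the token budget...
    decode_tokens = max_new_tokens + speculate
    # ...then rounded up to a multiple of the block size ("pad to block size").
    return ((decode_tokens + block_size - 1) // block_size) * block_size

# Example: block_size=16, speculate=2, a request for 10 new tokens
# reserves ((10 + 2 + 15) // 16) * 16 = 16 decode tokens.
assert reserved_decode_tokens(10, 2, 16) == 16
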
@@ -359,7 +375,7 @@ mod tests {
#[test]
fn test_append() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@@ -375,7 +391,7 @@ mod tests {
#[test]
fn test_next_batch_empty() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
@@ -383,7 +399,7 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -415,7 +431,7 @@ mod tests {
#[test]
fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -448,14 +464,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@@ -463,7 +479,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -496,7 +512,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -519,9 +535,20 @@ mod tests {
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
assert!(queue.next_batch(None, 1, 1).await.is_none());
}
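
The new test follows directly from that budget change: with speculate = 2, even a minimal entry now needs more decode tokens than the budget of 1 passed to next_batch, so no batch can be formed. In rough numbers (assuming default_entry requests a single new token, as the other budget tests imply):

token_budget = 1
speculate = 2
max_new_tokens = 1  # assumed minimal request from default_entry

assert max_new_tokens + speculate > token_budget  # 3 > 1, so next_batch returns None
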
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry, _) = default_entry();
queue.append(entry);

View File

@@ -596,6 +596,7 @@ pub async fn run(
max_concurrent_requests,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
);

View File

@@ -938,8 +938,6 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
# next_token_ids,
# next_token_logprobs,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
@@ -957,8 +955,6 @@ class FlashCausalLM(Model):
do_sample,
seed,
top_n_tokens,
# next_token_id,
# next_token_logprob,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
@@ -968,21 +964,18 @@ class FlashCausalLM(Model):
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
next_token_texts = []
left = 0
for j in range(index, index + n_accepted_ids):
# Generated token
all_input_ids.append(next_token_ids[j])
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
index += n_accepted_ids
# Evaluate stopping criteria
left = 0
for j, next_token_id in enumerate(_next_token_ids):
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
@@ -994,6 +987,7 @@ class FlashCausalLM(Model):
break
else:
stopped = False
index += n_accepted_ids
_next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
# Shard generations
@@ -1003,7 +997,7 @@ class FlashCausalLM(Model):
# Decode generated tokens
# Remove potentially accepted ids that do not respect
# the stopping_criteria
_ids = all_input_ids[:len(all_input_ids)-left]
_ids = all_input_ids
output_text, _, _ = self.decode_token(
_ids,
prefix_offset=len(_ids)
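
For each request, the loop above appends every accepted speculative token, decodes its text, then applies the stopping criteria and drops whatever was accepted past a stop. A simplified, self-contained sketch of that acceptance pattern (hypothetical helper and parameter names, not the actual FlashCausalLM code):

def accept_speculative_tokens(accepted_ids, decode_token, stopping_criteria):
    # Keep accepted tokens only up to (and including) the first one that triggers a stop.
    kept, texts = [], []
    for token_id in accepted_ids:
        text = decode_token(token_id)
        kept.append(token_id)
        texts.append(text)
        stop, _reason = stopping_criteria(token_id, text)
        if stop:
            break
    return kept, texts

# Example: a hypothetical EOS id of 2 cuts the accepted sequence short.
kept, _ = accept_speculative_tokens(
    [5, 7, 2, 9],
    decode_token=str,
    stopping_criteria=lambda token_id, text: (token_id == 2, "eos" if token_id == 2 else None),
)
assert kept == [5, 7, 2]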

View File

@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ class Model(ABC):
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
):
self.model = model.eval()
self.tokenizer = tokenizer
@@ -33,6 +35,10 @@ class Model(ABC):
self.world_size = world_size
self.sliding_window = sliding_window
if speculate is None:
speculate = get_speculate()
self.speculate = speculate
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
@@ -50,6 +56,7 @@ class Model(ABC):
dtype=str(self.dtype),
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate
)
@property
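
The defaulting rule above means an explicit speculate value wins, and otherwise the globally configured amount from get_speculate() is used and reported to the router through InfoResponse. A minimal sketch of just that rule (hypothetical helper name, standing in for the check in Model.__init__):

from typing import Optional

def resolve_speculate(explicit: Optional[int], global_default: int) -> int:
    # Mirrors the fallback above: an explicit value wins, else the global setting.
    return global_default if explicit is None else explicit

assert resolve_speculate(None, 2) == 2  # no override -> value from get_speculate()
assert resolve_speculate(0, 2) == 0     # explicit 0 keeps speculation disabled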