add window size in proto

OlivierDehaene 2023-09-27 12:20:20 +02:00
parent 2811ec9bff
commit 630e417ca0
8 changed files with 40 additions and 14 deletions

View File

@@ -31,6 +31,7 @@ message InfoResponse {
     bool requires_padding = 1;
     string dtype = 2;
     string device_type = 3;
+    optional uint32 window_size = 4;
 }

 /// Empty request
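For context (not part of the commit): prost typically maps a proto3 `optional uint32` to `Option<u32>`, which is why `window_size` appears as `Option<u32>` in the Rust hunks below. A trimmed, illustrative sketch of what the generated type could look like; the struct layout and the value 4096 are assumptions, not the actual generated code.

// Illustrative only: rough shape of the message above on the Rust side.
#[derive(Debug, Default)]
pub struct InfoResponse {
    pub requires_padding: bool,   // field 1
    pub dtype: String,            // field 2
    pub device_type: String,      // field 3
    pub window_size: Option<u32>, // field 4 (new): None when the model has no sliding window
}

fn main() {
    // A sliding-window model (e.g. Mistral) would report a finite window; 4096 is made up here.
    let info = InfoResponse {
        window_size: Some(4096),
        ..Default::default()
    };
    println!("{info:?}");
}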

View File

@@ -50,10 +50,11 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16);
+        let queue = Queue::new(requires_padding, 16, window_size);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });

View File

@@ -2,6 +2,7 @@ use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
+use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
@@ -33,12 +34,17 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
+    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();

         // Launch background queue task
-        tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
+        tokio::spawn(queue_task(
+            requires_padding,
+            block_size,
+            window_size,
+            queue_receiver,
+        ));

         Self { queue_sender }
     }
@@ -84,9 +90,10 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     block_size: u32,
+    window_size: Option<u32>,
     receiver: flume::Receiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size);
+    let mut state = State::new(requires_padding, block_size, window_size);

     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
@@ -126,16 +133,20 @@ struct State {
     /// Paged Attention block size
     block_size: u32,
+
+    /// Sliding window
+    window_size: Option<u32>,
 }

 impl State {
-    fn new(requires_padding: bool, block_size: u32) -> Self {
+    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
             block_size,
+            window_size,
         }
     }
@@ -204,11 +215,17 @@ impl State {
             if self.requires_padding {
                 decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             } else {
+                let max_new_tokens = match self.window_size {
+                    None => entry.request.stopping_parameters.max_new_tokens,
+                    Some(window_size) => min(
+                        window_size.saturating_sub(entry.request.input_length),
+                        entry.request.stopping_parameters.max_new_tokens,
+                    ),
+                };
+
                 // pad to block size
                 decode_tokens +=
-                    ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
-                        / self.block_size)
-                        * self.block_size;
+                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
             }

             if prefill_tokens > prefill_token_budget
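For illustration only (not part of the commit), a small standalone sketch of the budgeting rule in the hunk above: with a sliding window, a request can never occupy more than `window_size` cache slots, so its decode-token budget is capped at `window_size - input_length` before being rounded up to whole Paged Attention blocks. The `block_size = 16` and the request sizes below are assumed values.

use std::cmp::min;

// Mirrors the hunk above: decode tokens budgeted for one request, padded to the block size.
fn decode_token_budget(
    input_length: u32,
    max_new_tokens: u32,
    block_size: u32,
    window_size: Option<u32>,
) -> u32 {
    let max_new_tokens = match window_size {
        // No sliding window: budget every requested new token.
        None => max_new_tokens,
        // Sliding window: past `window_size` tokens the cache slots are reused,
        // so only the tokens that still fit in the window need fresh slots.
        Some(window_size) => min(window_size.saturating_sub(input_length), max_new_tokens),
    };
    // Round up to a whole number of Paged Attention blocks.
    ((max_new_tokens + block_size - 1) / block_size) * block_size
}

fn main() {
    // 1000 prompt tokens, 5000 requested tokens, a 4096-token window:
    // only 3096 tokens need new slots, padded up to 3104 (194 blocks of 16).
    assert_eq!(decode_token_budget(1000, 5000, 16, Some(4096)), 3104);
    // Without a window the full 5000 tokens are budgeted, padded up to 5008.
    assert_eq!(decode_token_budget(1000, 5000, 16, None), 5008);
}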

View File

@@ -595,6 +595,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        shard_info.window_size,
         generation_health,
     );

View File

@@ -1,4 +1,4 @@
-vllm_commit := e86af624d059969b0fb07b075b1d338bf10c3365
+vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78

 vllm:
 	# Clone vllm

View File

@@ -636,12 +636,11 @@ class FlashCausalLM(Model):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
-        repeat_slots: bool = False,
+        sliding_window: Optional[int] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
-        self.repeat_slots = repeat_slots

         super(FlashCausalLM, self).__init__(
             model=model,
@@ -651,6 +650,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
+            sliding_window=sliding_window,
         )

     @property
@@ -665,7 +665,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.repeat_slots,
+            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
@@ -705,7 +705,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.repeat_slots,
+            self.sliding_window is not None,
             self.dtype,
             self.device,
         )

View File

@@ -331,7 +331,7 @@ class FlashMistral(FlashCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
-            repeat_slots=True,
+            sliding_window=config.sliding_window,
         )

     @property

View File

@@ -21,6 +21,7 @@ class Model(ABC):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
+        sliding_window: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -30,6 +31,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
+        self.sliding_window = sliding_window

         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -40,10 +42,14 @@ class Model(ABC):
     @property
     def info(self) -> InfoResponse:
+        if self.requires_padding and self.sliding_window is not None:
+            raise NotImplementedError("sliding_window is not implemented with padding")
+
         return InfoResponse(
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
+            window_size=self.sliding_window,
         )

     @property