Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
add window size in proto

Commit 630e417ca0 (parent 2811ec9bff)
@@ -31,6 +31,7 @@ message InfoResponse {
     bool requires_padding = 1;
     string dtype = 2;
     string device_type = 3;
+    optional uint32 window_size = 4;
 }

 /// Empty request
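
Note: a proto3 `optional` scalar like the new `window_size` field surfaces as `Option<u32>` in prost-generated Rust bindings, which is why the router-side code below threads a `window_size: Option<u32>` around. A minimal illustrative sketch only; the struct is a simplified stand-in for the generated client type and the values in `main` are made up, not taken from this commit:

// Simplified stand-in for the generated InfoResponse; field names follow the proto above.
struct InfoResponse {
    requires_padding: bool,
    dtype: String,
    device_type: String,
    // `optional uint32 window_size = 4;` maps to Option<u32>; None means full attention
    window_size: Option<u32>,
}

fn describe(info: &InfoResponse) -> String {
    match info.window_size {
        None => "full attention (no sliding window)".to_string(),
        Some(w) => format!("sliding-window attention over the last {w} tokens"),
    }
}

fn main() {
    let info = InfoResponse {
        requires_padding: false,
        dtype: "torch.float16".to_string(),
        device_type: "cuda".to_string(),
        window_size: Some(4096), // illustrative Mistral-style window
    };
    println!("{}", describe(&info));
}
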
@@ -50,10 +50,11 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16);
+        let queue = Queue::new(requires_padding, 16, window_size);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -2,6 +2,7 @@ use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
+use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
@@ -33,12 +34,17 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
+    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();

         // Launch background queue task
-        tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
+        tokio::spawn(queue_task(
+            requires_padding,
+            block_size,
+            window_size,
+            queue_receiver,
+        ));

         Self { queue_sender }
     }
@@ -84,9 +90,10 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     block_size: u32,
+    window_size: Option<u32>,
     receiver: flume::Receiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size);
+    let mut state = State::new(requires_padding, block_size, window_size);

     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
@@ -126,16 +133,20 @@ struct State {

     /// Paged Attention block size
     block_size: u32,
+
+    /// Sliding window
+    window_size: Option<u32>,
 }

 impl State {
-    fn new(requires_padding: bool, block_size: u32) -> Self {
+    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
             block_size,
+            window_size,
         }
     }

@@ -204,11 +215,17 @@ impl State {
             if self.requires_padding {
                 decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             } else {
+                let max_new_tokens = match self.window_size {
+                    None => entry.request.stopping_parameters.max_new_tokens,
+                    Some(window_size) => min(
+                        window_size.saturating_sub(entry.request.input_length),
+                        entry.request.stopping_parameters.max_new_tokens,
+                    ),
+                };
+
                 // pad to block size
                 decode_tokens +=
-                    ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
-                        / self.block_size)
-                        * self.block_size;
+                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
             }

             if prefill_tokens > prefill_token_budget
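
The queue change above is the core of the scheduling side: with a sliding window, a sequence never keeps more than `window_size` tokens in the KV cache, so the scheduler only needs to budget `min(window_size - input_length, max_new_tokens)` extra decode tokens, rounded up to the Paged Attention block size. A self-contained sketch of that budgeting rule with a worked example; the numeric values are illustrative only, not taken from the repository:

use std::cmp::min;

/// Decode-token budget for one request, mirroring the logic added above.
fn decode_token_budget(
    input_length: u32,
    max_new_tokens: u32,
    block_size: u32,
    window_size: Option<u32>,
) -> u32 {
    let max_new_tokens = match window_size {
        // No sliding window: reserve room for every requested new token.
        None => max_new_tokens,
        // Sliding window: the KV cache never holds more than `window_size` tokens,
        // so only the generation that still fits in the window needs new slots.
        Some(window_size) => min(window_size.saturating_sub(input_length), max_new_tokens),
    };
    // Pad up to a multiple of the Paged Attention block size.
    ((max_new_tokens + block_size - 1) / block_size) * block_size
}

fn main() {
    // Illustrative numbers: a 4096-token window, 16-token blocks.
    // A 3500-token prompt asking for 1024 new tokens only needs
    // 4096 - 3500 = 596 extra KV-cache tokens, padded to 608 (38 blocks).
    assert_eq!(decode_token_budget(3500, 1024, 16, Some(4096)), 608);
    // Without a sliding window the full 1024 tokens are reserved.
    assert_eq!(decode_token_budget(3500, 1024, 16, None), 1024);
    println!("budgets check out");
}
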
@@ -595,6 +595,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        shard_info.window_size,
         generation_health,
     );

@@ -1,4 +1,4 @@
-vllm_commit := e86af624d059969b0fb07b075b1d338bf10c3365
+vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78

 vllm:
 	# Clone vllm
@@ -636,12 +636,11 @@ class FlashCausalLM(Model):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
-        repeat_slots: bool = False,
+        sliding_window: Optional[int] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
-        self.repeat_slots = repeat_slots

         super(FlashCausalLM, self).__init__(
             model=model,
@@ -651,6 +650,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
+            sliding_window=sliding_window,
         )

     @property
@@ -665,7 +665,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.repeat_slots,
+            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
@@ -705,7 +705,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.repeat_slots,
+            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
@@ -331,7 +331,7 @@ class FlashMistral(FlashCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
-            repeat_slots=True,
+            sliding_window=config.sliding_window,
         )

     @property
@@ -21,6 +21,7 @@ class Model(ABC):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
+        sliding_window: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -30,6 +31,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
+        self.sliding_window = sliding_window

         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -40,10 +42,14 @@ class Model(ABC):

     @property
     def info(self) -> InfoResponse:
+        if self.requires_padding and self.sliding_window is not None:
+            raise NotImplementedError("sliding_window is not implemented with padding")
+
         return InfoResponse(
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
+            window_size=self.sliding_window,
         )

     @property