feat: propagate max_concurrent_requests to queue state entries instead of hardcoded 128 in backend/v2

This commit is contained in:
Venkat Raman 2024-09-26 18:02:42 +02:00
parent 0aa66d693a
commit 77ddc8309d
8 changed files with 1386 additions and 16 deletions

View File

@ -0,0 +1,647 @@
// This file is @generated by prost-build.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HealthRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HealthResponse {}
/// / Empty request
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InfoRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InfoResponse {
#[prost(bool, tag = "1")]
pub requires_padding: bool,
#[prost(string, tag = "2")]
pub dtype: ::prost::alloc::string::String,
#[prost(string, tag = "3")]
pub device_type: ::prost::alloc::string::String,
#[prost(uint32, optional, tag = "4")]
pub window_size: ::core::option::Option<u32>,
#[prost(uint32, tag = "5")]
pub speculate: u32,
}
/// / Empty request
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ServiceDiscoveryRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ServiceDiscoveryResponse {
/// / Other shards urls
#[prost(string, repeated, tag = "1")]
pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ClearCacheRequest {
/// / Optional batch id
#[prost(uint64, optional, tag = "1")]
pub id: ::core::option::Option<u64>,
}
/// / Empty response
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ClearCacheResponse {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NextTokenChooserParameters {
/// / exponential scaling output probability distribution
#[prost(float, tag = "1")]
pub temperature: f32,
/// / restricting to the k highest probability elements
#[prost(uint32, tag = "2")]
pub top_k: u32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
#[prost(float, tag = "3")]
pub top_p: f32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
#[prost(float, tag = "4")]
pub typical_p: f32,
/// / apply sampling on the logits
#[prost(bool, tag = "5")]
pub do_sample: bool,
/// / random seed for sampling
#[prost(uint64, tag = "6")]
pub seed: u64,
/// / repetition penalty
#[prost(float, tag = "7")]
pub repetition_penalty: f32,
/// / frequency penalty
#[prost(float, tag = "9")]
pub frequency_penalty: f32,
/// / token watermarking using "A Watermark for Large Language Models"
#[prost(bool, tag = "8")]
pub watermark: bool,
/// / grammar (applied if not empty)
#[prost(string, tag = "10")]
pub grammar: ::prost::alloc::string::String,
/// / grammar type
#[prost(enumeration = "GrammarType", tag = "11")]
pub grammar_type: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StoppingCriteriaParameters {
/// / Maximum number of generated tokens
#[prost(uint32, tag = "1")]
pub max_new_tokens: u32,
/// / Optional stopping sequences
#[prost(string, repeated, tag = "2")]
pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
/// / Ignore end of sequence token
/// / used for benchmarking
#[prost(bool, tag = "3")]
pub ignore_eos_token: bool,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Request {
/// / Request ID
#[prost(uint64, tag = "1")]
pub id: u64,
/// / The generation context
#[prost(string, tag = "2")]
pub inputs: ::prost::alloc::string::String,
/// / Context truncation
#[prost(uint32, tag = "3")]
pub truncate: u32,
/// / Next Token Chooser Parameters
#[prost(message, optional, tag = "4")]
pub parameters: ::core::option::Option<NextTokenChooserParameters>,
/// / Stopping Criteria Parameters
#[prost(message, optional, tag = "5")]
pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
/// / Return prefill logprobs
#[prost(bool, tag = "6")]
pub prefill_logprobs: bool,
/// / Return most likely n tokens
#[prost(uint32, tag = "7")]
pub top_n_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Batch {
/// / Batch ID
#[prost(uint64, tag = "1")]
pub id: u64,
/// / Individual requests
#[prost(message, repeated, tag = "2")]
pub requests: ::prost::alloc::vec::Vec<Request>,
/// / Batch size (==len(requests))
#[prost(uint32, tag = "3")]
pub size: u32,
/// / Maximum number of tokens this batch will grow to
#[prost(uint32, tag = "4")]
pub max_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CachedBatch {
/// / Batch ID
#[prost(uint64, tag = "1")]
pub id: u64,
/// / Individual requests ids
#[prost(uint64, repeated, tag = "2")]
pub request_ids: ::prost::alloc::vec::Vec<u64>,
/// / Batch size (==len(requests))
#[prost(uint32, tag = "3")]
pub size: u32,
/// / Maximum number of tokens this batch will grow to
#[prost(uint32, tag = "4")]
pub max_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct GeneratedText {
/// / Output
#[prost(string, tag = "1")]
pub text: ::prost::alloc::string::String,
/// / Number of generated tokens
#[prost(uint32, tag = "2")]
pub generated_tokens: u32,
/// / Finish reason
#[prost(enumeration = "FinishReason", tag = "3")]
pub finish_reason: i32,
/// / Seed
#[prost(uint64, optional, tag = "4")]
pub seed: ::core::option::Option<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Tokens {
/// / Token IDs
#[prost(uint32, repeated, tag = "1")]
pub ids: ::prost::alloc::vec::Vec<u32>,
/// / Logprobs
#[prost(float, repeated, tag = "2")]
pub logprobs: ::prost::alloc::vec::Vec<f32>,
/// / tokens
#[prost(string, repeated, tag = "3")]
pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
/// / special
#[prost(bool, repeated, tag = "4")]
pub is_special: ::prost::alloc::vec::Vec<bool>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Generation {
/// / Request ID
#[prost(uint64, tag = "1")]
pub request_id: u64,
/// / Prefill tokens (optional)
#[prost(message, optional, tag = "2")]
pub prefill_tokens: ::core::option::Option<Tokens>,
#[prost(message, optional, tag = "3")]
pub tokens: ::core::option::Option<Tokens>,
/// / Complete generated text
#[prost(message, optional, tag = "4")]
pub generated_text: ::core::option::Option<GeneratedText>,
/// / Top tokens
#[prost(message, repeated, tag = "5")]
pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterBatchRequest {
/// / Batch ID
#[prost(uint64, tag = "1")]
pub batch_id: u64,
/// / Requests to keep
#[prost(uint64, repeated, tag = "2")]
pub request_ids: ::prost::alloc::vec::Vec<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterBatchResponse {
/// / Filtered Batch (cached)
#[prost(message, optional, tag = "1")]
pub batch: ::core::option::Option<CachedBatch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrefillRequest {
/// / Batch
#[prost(message, optional, tag = "1")]
pub batch: ::core::option::Option<Batch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrefillResponse {
/// / Generation
#[prost(message, repeated, tag = "1")]
pub generations: ::prost::alloc::vec::Vec<Generation>,
/// / Next batch (cached)
#[prost(message, optional, tag = "2")]
pub batch: ::core::option::Option<CachedBatch>,
/// / Forward elapsed time in nanoseconds
#[prost(uint64, tag = "3")]
pub forward_ns: u64,
/// / Decode elapsed time in nanoseconds
#[prost(uint64, tag = "4")]
pub decode_ns: u64,
/// / Total elapsed time in nanoseconds
#[prost(uint64, tag = "5")]
pub total_ns: u64,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DecodeRequest {
/// / Cached batches
#[prost(message, repeated, tag = "1")]
pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DecodeResponse {
/// / Decodes
#[prost(message, repeated, tag = "1")]
pub generations: ::prost::alloc::vec::Vec<Generation>,
/// / Next batch (cached)
#[prost(message, optional, tag = "2")]
pub batch: ::core::option::Option<CachedBatch>,
/// / Forward elapsed time in nanoseconds
#[prost(uint64, tag = "3")]
pub forward_ns: u64,
/// / Decode elapsed time in nanoseconds
#[prost(uint64, tag = "4")]
pub decode_ns: u64,
/// / Total elapsed time in nanoseconds
#[prost(uint64, tag = "5")]
pub total_ns: u64,
/// / Concatenate elapsed time in nanoseconds
#[prost(uint64, optional, tag = "6")]
pub concat_ns: ::core::option::Option<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupRequest {
/// / Batch to warmup on
#[prost(message, optional, tag = "1")]
pub batch: ::core::option::Option<Batch>,
#[prost(uint32, tag = "2")]
pub max_input_length: u32,
#[prost(uint32, tag = "3")]
pub max_prefill_tokens: u32,
#[prost(uint32, tag = "4")]
pub max_total_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupResponse {
/// / Maximum number of tokens supported by the model
#[prost(uint32, optional, tag = "1")]
pub max_supported_total_tokens: ::core::option::Option<u32>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum GrammarType {
None = 0,
Json = 1,
Regex = 2,
}
impl GrammarType {
/// String value of the enum field names used in the ProtoBuf definition.
///
/// The values are not transformed in any way and thus are considered stable
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
GrammarType::None => "GRAMMAR_TYPE_NONE",
GrammarType::Json => "GRAMMAR_TYPE_JSON",
GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
match value {
"GRAMMAR_TYPE_NONE" => Some(Self::None),
"GRAMMAR_TYPE_JSON" => Some(Self::Json),
"GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
_ => None,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum FinishReason {
Length = 0,
EosToken = 1,
StopSequence = 2,
}
impl FinishReason {
/// String value of the enum field names used in the ProtoBuf definition.
///
/// The values are not transformed in any way and thus are considered stable
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
FinishReason::Length => "FINISH_REASON_LENGTH",
FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
match value {
"FINISH_REASON_LENGTH" => Some(Self::Length),
"FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
"FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
_ => None,
}
}
}
/// Generated client implementations.
pub mod text_generation_service_client {
#![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
use tonic::codegen::*;
use tonic::codegen::http::Uri;
#[derive(Debug, Clone)]
pub struct TextGenerationServiceClient<T> {
inner: tonic::client::Grpc<T>,
}
impl TextGenerationServiceClient<tonic::transport::Channel> {
/// Attempt to create a new client by connecting to a given endpoint.
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
where
D: TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
Ok(Self::new(conn))
}
}
impl<T> TextGenerationServiceClient<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(inner: T) -> Self {
let inner = tonic::client::Grpc::new(inner);
Self { inner }
}
pub fn with_origin(inner: T, origin: Uri) -> Self {
let inner = tonic::client::Grpc::with_origin(inner, origin);
Self { inner }
}
pub fn with_interceptor<F>(
inner: T,
interceptor: F,
) -> TextGenerationServiceClient<InterceptedService<T, F>>
where
F: tonic::service::Interceptor,
T::ResponseBody: Default,
T: tonic::codegen::Service<
http::Request<tonic::body::BoxBody>,
Response = http::Response<
<T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
>,
>,
<T as tonic::codegen::Service<
http::Request<tonic::body::BoxBody>,
>>::Error: Into<StdError> + Send + Sync,
{
TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
}
/// Compress requests with the given encoding.
///
/// This requires the server to support it otherwise it might respond with an
/// error.
#[must_use]
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.send_compressed(encoding);
self
}
/// Enable decompressing responses.
#[must_use]
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.accept_compressed(encoding);
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
self
}
/// / Model Info
pub async fn info(
&mut self,
request: impl tonic::IntoRequest<super::InfoRequest>,
) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/Info",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
self.inner.unary(req, path, codec).await
}
/// / Service discovery
pub async fn service_discovery(
&mut self,
request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
) -> std::result::Result<
tonic::Response<super::ServiceDiscoveryResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/ServiceDiscovery",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(
GrpcMethod::new(
"generate.v2.TextGenerationService",
"ServiceDiscovery",
),
);
self.inner.unary(req, path, codec).await
}
/// / Empties batch cache
pub async fn clear_cache(
&mut self,
request: impl tonic::IntoRequest<super::ClearCacheRequest>,
) -> std::result::Result<
tonic::Response<super::ClearCacheResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/ClearCache",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(
GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"),
);
self.inner.unary(req, path, codec).await
}
/// / Remove requests from a cached batch
pub async fn filter_batch(
&mut self,
request: impl tonic::IntoRequest<super::FilterBatchRequest>,
) -> std::result::Result<
tonic::Response<super::FilterBatchResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/FilterBatch",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(
GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"),
);
self.inner.unary(req, path, codec).await
}
/// / Warmup the model and compute max cache size
pub async fn warmup(
&mut self,
request: impl tonic::IntoRequest<super::WarmupRequest>,
) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/Warmup",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup"));
self.inner.unary(req, path, codec).await
}
/// / Prefill batch and decode first token
pub async fn prefill(
&mut self,
request: impl tonic::IntoRequest<super::PrefillRequest>,
) -> std::result::Result<
tonic::Response<super::PrefillResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/Prefill",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill"));
self.inner.unary(req, path, codec).await
}
/// / Decode token for a list of prefilled batches
pub async fn decode(
&mut self,
request: impl tonic::IntoRequest<super::DecodeRequest>,
) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/Decode",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode"));
self.inner.unary(req, path, codec).await
}
/// / Health check
pub async fn health(
&mut self,
request: impl tonic::IntoRequest<super::HealthRequest>,
) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v2.TextGenerationService/Health",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health"));
self.inner.unary(req, path, codec).await
}
}
}

View File

@ -0,0 +1,6 @@
// This file is @generated by prost-build.
pub mod generate {
pub mod v2 {
include!("generate.v2.rs");
}
}

View File

@ -0,0 +1,703 @@
// This file is @generated by prost-build.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HealthRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HealthResponse {}
/// / Empty request
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InfoRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InfoResponse {
#[prost(bool, tag = "1")]
pub requires_padding: bool,
#[prost(string, tag = "2")]
pub dtype: ::prost::alloc::string::String,
#[prost(string, tag = "3")]
pub device_type: ::prost::alloc::string::String,
#[prost(uint32, optional, tag = "4")]
pub window_size: ::core::option::Option<u32>,
#[prost(uint32, tag = "5")]
pub speculate: u32,
}
/// / Empty request
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ServiceDiscoveryRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ServiceDiscoveryResponse {
/// / Other shards urls
#[prost(string, repeated, tag = "1")]
pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ClearCacheRequest {
/// / Optional batch id
#[prost(uint64, optional, tag = "1")]
pub id: ::core::option::Option<u64>,
}
/// / Empty response
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ClearCacheResponse {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Image {
/// / Binary image data.
#[prost(bytes = "vec", tag = "1")]
pub data: ::prost::alloc::vec::Vec<u8>,
/// / Image MIME type.
#[prost(string, tag = "2")]
pub mimetype: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InputChunk {
#[prost(oneof = "input_chunk::Chunk", tags = "1, 2")]
pub chunk: ::core::option::Option<input_chunk::Chunk>,
}
/// Nested message and enum types in `InputChunk`.
pub mod input_chunk {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Chunk {
/// / Plain text data
#[prost(string, tag = "1")]
Text(::prost::alloc::string::String),
/// / Image data
#[prost(message, tag = "2")]
Image(super::Image),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Input {
#[prost(message, repeated, tag = "1")]
pub chunks: ::prost::alloc::vec::Vec<InputChunk>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NextTokenChooserParameters {
/// / exponential scaling output probability distribution
#[prost(float, tag = "1")]
pub temperature: f32,
/// / restricting to the k highest probability elements
#[prost(uint32, tag = "2")]
pub top_k: u32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
#[prost(float, tag = "3")]
pub top_p: f32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
#[prost(float, tag = "4")]
pub typical_p: f32,
/// / apply sampling on the logits
#[prost(bool, tag = "5")]
pub do_sample: bool,
/// / random seed for sampling
#[prost(uint64, tag = "6")]
pub seed: u64,
/// / repetition penalty
#[prost(float, tag = "7")]
pub repetition_penalty: f32,
/// / frequency penalty
#[prost(float, tag = "9")]
pub frequency_penalty: f32,
/// / token watermarking using "A Watermark for Large Language Models"
#[prost(bool, tag = "8")]
pub watermark: bool,
/// / grammar (applied if not empty)
#[prost(string, tag = "10")]
pub grammar: ::prost::alloc::string::String,
/// / grammar type
#[prost(enumeration = "GrammarType", tag = "11")]
pub grammar_type: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StoppingCriteriaParameters {
/// / Maximum number of generated tokens
#[prost(uint32, tag = "1")]
pub max_new_tokens: u32,
/// / Optional stopping sequences
#[prost(string, repeated, tag = "2")]
pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
/// / Ignore end of sequence token
/// / used for benchmarking
#[prost(bool, tag = "3")]
pub ignore_eos_token: bool,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Request {
/// / Request ID
#[prost(uint64, tag = "1")]
pub id: u64,
/// / The generation context as chunks
#[prost(message, optional, tag = "8")]
pub input_chunks: ::core::option::Option<Input>,
/// / The generation context, stringified input_chunks
#[prost(string, tag = "2")]
pub inputs: ::prost::alloc::string::String,
/// / Context truncation
#[prost(uint32, tag = "3")]
pub truncate: u32,
/// / Next Token Chooser Parameters
#[prost(message, optional, tag = "4")]
pub parameters: ::core::option::Option<NextTokenChooserParameters>,
/// / Stopping Criteria Parameters
#[prost(message, optional, tag = "5")]
pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
/// / Return prefill logprobs
#[prost(bool, tag = "6")]
pub prefill_logprobs: bool,
/// / Return most likely n tokens
#[prost(uint32, tag = "7")]
pub top_n_tokens: u32,
/// / Paged attention blocks
#[prost(uint32, repeated, tag = "9")]
pub blocks: ::prost::alloc::vec::Vec<u32>,
/// / Paged attention slots
#[prost(uint32, repeated, tag = "10")]
pub slots: ::prost::alloc::vec::Vec<u32>,
/// / LORA adapter index
#[prost(string, optional, tag = "11")]
pub adapter_id: ::core::option::Option<::prost::alloc::string::String>,
/// / Prefix length that can be retrieved from the KV cache.
#[prost(uint32, tag = "12")]
pub prefix_len: u32,
/// / Context truncation
#[prost(bool, tag = "13")]
pub add_special_tokens: bool,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Batch {
/// / Batch ID
#[prost(uint64, tag = "1")]
pub id: u64,
/// / Individual requests
#[prost(message, repeated, tag = "2")]
pub requests: ::prost::alloc::vec::Vec<Request>,
/// / Batch size (==len(requests))
#[prost(uint32, tag = "3")]
pub size: u32,
/// / Maximum number of tokens this batch will grow to
#[prost(uint32, tag = "4")]
pub max_tokens: u32,
/// / Maximum number of Paged Attention blocks
#[prost(uint32, tag = "5")]
pub max_blocks: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CachedBatch {
/// / Batch ID
#[prost(uint64, tag = "1")]
pub id: u64,
/// / Individual requests ids
#[prost(uint64, repeated, tag = "2")]
pub request_ids: ::prost::alloc::vec::Vec<u64>,
/// / Batch size (==len(requests))
#[prost(uint32, tag = "3")]
pub size: u32,
/// / Maximum number of tokens this batch will grow to
#[prost(uint32, tag = "4")]
pub max_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct GeneratedText {
/// / Output
#[prost(string, tag = "1")]
pub text: ::prost::alloc::string::String,
/// / Number of generated tokens
#[prost(uint32, tag = "2")]
pub generated_tokens: u32,
/// / Finish reason
#[prost(enumeration = "FinishReason", tag = "3")]
pub finish_reason: i32,
/// / Seed
#[prost(uint64, optional, tag = "4")]
pub seed: ::core::option::Option<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Tokens {
/// / Token IDs
#[prost(uint32, repeated, tag = "1")]
pub ids: ::prost::alloc::vec::Vec<u32>,
/// / Logprobs
#[prost(float, repeated, tag = "2")]
pub logprobs: ::prost::alloc::vec::Vec<f32>,
/// / tokens
#[prost(string, repeated, tag = "3")]
pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
/// / special
#[prost(bool, repeated, tag = "4")]
pub is_special: ::prost::alloc::vec::Vec<bool>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Generation {
/// / Request ID
#[prost(uint64, tag = "1")]
pub request_id: u64,
/// / Prefill tokens (optional)
#[prost(message, optional, tag = "2")]
pub prefill_tokens: ::core::option::Option<Tokens>,
#[prost(message, optional, tag = "3")]
pub tokens: ::core::option::Option<Tokens>,
/// / Complete generated text
#[prost(message, optional, tag = "4")]
pub generated_text: ::core::option::Option<GeneratedText>,
/// / Top tokens
#[prost(message, repeated, tag = "5")]
pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterBatchRequest {
/// / Batch ID
#[prost(uint64, tag = "1")]
pub batch_id: u64,
/// / Requests to keep
#[prost(uint64, repeated, tag = "2")]
pub request_ids: ::prost::alloc::vec::Vec<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterBatchResponse {
/// / Filtered Batch (cached)
#[prost(message, optional, tag = "1")]
pub batch: ::core::option::Option<CachedBatch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrefillRequest {
/// / Batch
#[prost(message, optional, tag = "1")]
pub batch: ::core::option::Option<Batch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrefillResponse {
/// / Generation
#[prost(message, repeated, tag = "1")]
pub generations: ::prost::alloc::vec::Vec<Generation>,
/// / Next batch (cached)
#[prost(message, optional, tag = "2")]
pub batch: ::core::option::Option<CachedBatch>,
/// / Forward elapsed time in nanoseconds
#[prost(uint64, tag = "3")]
pub forward_ns: u64,
/// / Decode elapsed time in nanoseconds
#[prost(uint64, tag = "4")]
pub decode_ns: u64,
/// / Total elapsed time in nanoseconds
#[prost(uint64, tag = "5")]
pub total_ns: u64,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DecodeRequest {
/// / Cached batches
#[prost(message, repeated, tag = "1")]
pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DecodeResponse {
/// / Decodes
#[prost(message, repeated, tag = "1")]
pub generations: ::prost::alloc::vec::Vec<Generation>,
/// / Next batch (cached)
#[prost(message, optional, tag = "2")]
pub batch: ::core::option::Option<CachedBatch>,
/// / Forward elapsed time in nanoseconds
#[prost(uint64, tag = "3")]
pub forward_ns: u64,
/// / Decode elapsed time in nanoseconds
#[prost(uint64, tag = "4")]
pub decode_ns: u64,
/// / Total elapsed time in nanoseconds
#[prost(uint64, tag = "5")]
pub total_ns: u64,
/// / Concatenate elapsed time in nanoseconds
#[prost(uint64, optional, tag = "6")]
pub concat_ns: ::core::option::Option<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupRequest {
/// / Batch to warmup on
#[prost(message, optional, tag = "1")]
pub batch: ::core::option::Option<Batch>,
#[prost(uint32, tag = "2")]
pub max_input_length: u32,
#[prost(uint32, tag = "3")]
pub max_prefill_tokens: u32,
#[prost(uint32, tag = "4")]
pub max_total_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupResponse {
/// / Maximum number of tokens supported by the model
#[prost(uint32, optional, tag = "1")]
pub max_supported_total_tokens: ::core::option::Option<u32>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum GrammarType {
None = 0,
Json = 1,
Regex = 2,
}
impl GrammarType {
/// String value of the enum field names used in the ProtoBuf definition.
///
/// The values are not transformed in any way and thus are considered stable
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
GrammarType::None => "GRAMMAR_TYPE_NONE",
GrammarType::Json => "GRAMMAR_TYPE_JSON",
GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
match value {
"GRAMMAR_TYPE_NONE" => Some(Self::None),
"GRAMMAR_TYPE_JSON" => Some(Self::Json),
"GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
_ => None,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum FinishReason {
Length = 0,
EosToken = 1,
StopSequence = 2,
}
impl FinishReason {
/// String value of the enum field names used in the ProtoBuf definition.
///
/// The values are not transformed in any way and thus are considered stable
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
FinishReason::Length => "FINISH_REASON_LENGTH",
FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
match value {
"FINISH_REASON_LENGTH" => Some(Self::Length),
"FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
"FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
_ => None,
}
}
}
/// Generated client implementations.
pub mod text_generation_service_client {
#![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
use tonic::codegen::*;
use tonic::codegen::http::Uri;
#[derive(Debug, Clone)]
pub struct TextGenerationServiceClient<T> {
inner: tonic::client::Grpc<T>,
}
impl TextGenerationServiceClient<tonic::transport::Channel> {
/// Attempt to create a new client by connecting to a given endpoint.
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
where
D: TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
Ok(Self::new(conn))
}
}
impl<T> TextGenerationServiceClient<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(inner: T) -> Self {
let inner = tonic::client::Grpc::new(inner);
Self { inner }
}
pub fn with_origin(inner: T, origin: Uri) -> Self {
let inner = tonic::client::Grpc::with_origin(inner, origin);
Self { inner }
}
pub fn with_interceptor<F>(
inner: T,
interceptor: F,
) -> TextGenerationServiceClient<InterceptedService<T, F>>
where
F: tonic::service::Interceptor,
T::ResponseBody: Default,
T: tonic::codegen::Service<
http::Request<tonic::body::BoxBody>,
Response = http::Response<
<T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
>,
>,
<T as tonic::codegen::Service<
http::Request<tonic::body::BoxBody>,
>>::Error: Into<StdError> + Send + Sync,
{
TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
}
/// Compress requests with the given encoding.
///
/// This requires the server to support it otherwise it might respond with an
/// error.
#[must_use]
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.send_compressed(encoding);
self
}
/// Enable decompressing responses.
#[must_use]
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.accept_compressed(encoding);
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
self
}
/// / Model Info
pub async fn info(
&mut self,
request: impl tonic::IntoRequest<super::InfoRequest>,
) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/Info",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Info"));
self.inner.unary(req, path, codec).await
}
/// / Service discovery
pub async fn service_discovery(
&mut self,
request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
) -> std::result::Result<
tonic::Response<super::ServiceDiscoveryResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/ServiceDiscovery",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(
GrpcMethod::new(
"generate.v3.TextGenerationService",
"ServiceDiscovery",
),
);
self.inner.unary(req, path, codec).await
}
/// / Empties batch cache
pub async fn clear_cache(
&mut self,
request: impl tonic::IntoRequest<super::ClearCacheRequest>,
) -> std::result::Result<
tonic::Response<super::ClearCacheResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/ClearCache",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(
GrpcMethod::new("generate.v3.TextGenerationService", "ClearCache"),
);
self.inner.unary(req, path, codec).await
}
/// / Remove requests from a cached batch
pub async fn filter_batch(
&mut self,
request: impl tonic::IntoRequest<super::FilterBatchRequest>,
) -> std::result::Result<
tonic::Response<super::FilterBatchResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/FilterBatch",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(
GrpcMethod::new("generate.v3.TextGenerationService", "FilterBatch"),
);
self.inner.unary(req, path, codec).await
}
/// / Warmup the model and compute max cache size
pub async fn warmup(
&mut self,
request: impl tonic::IntoRequest<super::WarmupRequest>,
) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/Warmup",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Warmup"));
self.inner.unary(req, path, codec).await
}
/// / Prefill batch and decode first token
pub async fn prefill(
&mut self,
request: impl tonic::IntoRequest<super::PrefillRequest>,
) -> std::result::Result<
tonic::Response<super::PrefillResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/Prefill",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Prefill"));
self.inner.unary(req, path, codec).await
}
/// / Decode token for a list of prefilled batches
pub async fn decode(
&mut self,
request: impl tonic::IntoRequest<super::DecodeRequest>,
) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/Decode",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Decode"));
self.inner.unary(req, path, codec).await
}
/// / Health check
pub async fn health(
&mut self,
request: impl tonic::IntoRequest<super::HealthRequest>,
) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/generate.v3.TextGenerationService/Health",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Health"));
self.inner.unary(req, path, codec).await
}
}
}

View File

@ -0,0 +1,6 @@
// This file is @generated by prost-build.
pub mod generate {
pub mod v3 {
include!("generate.v3.rs");
}
}

View File

@ -31,6 +31,7 @@ impl BackendV2 {
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
@ -48,7 +49,7 @@ impl BackendV2 {
} else { } else {
16 16
}; };
let queue = Queue::new(requires_padding, block_size, window_size, speculate); let queue = Queue::new(max_concurrent_requests, requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic // Spawn batching background task that contains all the inference logic

View File

@ -39,6 +39,7 @@ pub async fn connect_backend(
max_batch_total_tokens: Option<u32>, max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_concurrent_requests: usize,
) -> Result<(BackendV2, BackendInfo), V2Error> { ) -> Result<(BackendV2, BackendInfo), V2Error> {
// Helper function // Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| { let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@ -116,12 +117,13 @@ pub async fn connect_backend(
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
max_concurrent_requests,
shard_info.requires_padding, shard_info.requires_padding,
shard_info.window_size, shard_info.window_size,
shard_info.speculate, shard_info.speculate,
); );
tracing::info!("Using backend V3"); tracing::info!("Using backend V2");
Ok((backend, backend_info)) Ok((backend, backend_info))
} }

View File

@ -167,6 +167,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
max_concurrent_requests,
) )
.await?; .await?;

View File

@ -39,6 +39,7 @@ pub(crate) struct Queue {
impl Queue { impl Queue {
pub(crate) fn new( pub(crate) fn new(
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
@ -49,6 +50,7 @@ impl Queue {
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task( tokio::spawn(queue_task(
max_concurrent_requests,
requires_padding, requires_padding,
block_size, block_size,
window_size, window_size,
@ -99,13 +101,14 @@ impl Queue {
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task( async fn queue_task(
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>, mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) { ) {
let mut state = State::new(requires_padding, block_size, window_size, speculate); let mut state = State::new(max_concurrent_requests, requires_padding, block_size, window_size, speculate);
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
@ -157,13 +160,14 @@ struct State {
impl State { impl State {
fn new( fn new(
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(max_concurrent_requests),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
requires_padding, requires_padding,
@ -452,7 +456,7 @@ mod tests {
#[test] #[test]
fn test_append() { fn test_append() {
let mut state = State::new(false, 1, None, 0); let mut state = State::new(128, false, 1, None, 0);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -468,7 +472,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_empty() { fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0); let mut state = State::new(128, false, 1, None, 0);
assert!(state.next_batch(None, None, 1, 1).is_none()); assert!(state.next_batch(None, None, 1, 1).is_none());
assert!(state.next_batch(Some(1), None, 1, 1).is_none()); assert!(state.next_batch(Some(1), None, 1, 1).is_none());
@ -476,7 +480,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_min_size() { fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0); let mut state = State::new(128, false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -508,7 +512,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_max_size() { fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0); let mut state = State::new(128, false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -528,7 +532,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_token_budget() { fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0); let mut state = State::new(128, false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -561,14 +565,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0); let queue = Queue::new(128, false, 1, None, 0);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0); let queue = Queue::new(128, false, 1, None, 0);
assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -576,7 +580,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0); let queue = Queue::new(128, false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -609,7 +613,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0); let queue = Queue::new(128, false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -625,7 +629,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0); let queue = Queue::new(128, false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -650,7 +654,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2); let queue = Queue::new(128, false, 1, None, 2);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -669,7 +673,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0); let queue = Queue::new(128, false, 1, None, 0);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);