From 2611c1a55f5b6c367c23812467bad2befb552f25 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 30 Jul 2024 15:27:57 +0200 Subject: [PATCH] Fixing client. --- Cargo.lock | 230 +++++++- Cargo.toml | 9 +- Dockerfile | 2 + Dockerfile_amd | 2 + Dockerfile_intel | 2 + backends/client/build.rs | 35 ++ backends/client/src/lib.rs | 91 +++ backends/client/src/v3/client.rs | 283 +++++++++ backends/client/src/v3/mod.rs | 13 + backends/client/src/v3/sharded_client.rs | 259 +++++++++ router/client/src/v2/pb/generate.v2.rs | 647 --------------------- router/client/src/v2/pb/mod.rs | 6 - router/client/src/v3/pb/generate.v3.rs | 697 ----------------------- router/client/src/v3/pb/mod.rs | 6 - 14 files changed, 917 insertions(+), 1365 deletions(-) create mode 100644 backends/client/build.rs create mode 100644 backends/client/src/lib.rs create mode 100644 backends/client/src/v3/client.rs create mode 100644 backends/client/src/v3/mod.rs create mode 100644 backends/client/src/v3/sharded_client.rs delete mode 100644 router/client/src/v2/pb/generate.v2.rs delete mode 100644 router/client/src/v2/pb/mod.rs delete mode 100644 router/client/src/v3/pb/generate.v3.rs delete mode 100644 router/client/src/v3/pb/mod.rs diff --git a/Cargo.lock b/Cargo.lock index adcef194..92367d1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -200,6 +200,17 @@ dependencies = [ "v_frame", ] +[[package]] +name = "average" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c309b1c7fca12ebeec3ecba29ea917b3a4cb458ccf504df68bb4d8a0ca565a00" +dependencies = [ + "easy-cast", + "float-ord", + "num-traits", +] + [[package]] name = "avif-serialize" version = "0.8.1" @@ -548,6 +559,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "cassowary" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" + [[package]] name = "cc" version = "1.1.7" @@ -628,7 +645,7 @@ version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.72", @@ -752,6 +769,31 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.6.0", + "crossterm_winapi", + "libc", + "mio 0.8.11", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -955,6 +997,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +[[package]] +name = "easy-cast" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6" +dependencies = [ + "libm", +] + [[package]] name = "either" version = "1.13.0" @@ -1058,6 +1109,12 @@ 
dependencies = [ "miniz_oxide", ] +[[package]] +name = "float-ord" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" + [[package]] name = "float_eq" version = "1.0.1" @@ -1335,6 +1392,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -1654,6 +1717,12 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + [[package]] name = "init-tracing-opentelemetry" version = "0.14.1" @@ -1840,6 +1909,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "libredox" version = "0.1.3" @@ -2039,6 +2114,18 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.1" @@ -2304,6 +2391,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -2601,6 +2689,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "papergrid" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2ccbe15f2b6db62f9a9871642746427e297b0ceb85f9a7f1ee5ff47d184d0c8" +dependencies = [ + "bytecount", + "fnv", + "unicode-width", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -2807,7 +2906,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -2925,6 +3024,23 @@ dependencies = [ "getrandom", ] +[[package]] +name = "ratatui" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad" +dependencies = [ + "bitflags 2.6.0", + "cassowary", + "crossterm", + "indoc", + "itertools 0.11.0", + "paste", + "strum", + "unicode-segmentation", + "unicode-width", +] + [[package]] name = "rav1e" version = "0.7.1" @@ -3489,6 +3605,27 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" +dependencies = [ + "libc", + "mio 0.8.11", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -3586,6 +3723,28 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.72", +] + [[package]] name = "subtle" version = "2.6.1" @@ -3669,12 +3828,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck", + "heck 0.5.0", "pkg-config", "toml", "version-compare", ] +[[package]] +name = "tabled" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfe9c3632da101aba5131ed63f9eed38665f8b3c68703a6bb18124835c1a5d22" +dependencies = [ + "papergrid", + "tabled_derive", + "unicode-width", +] + +[[package]] +name = "tabled_derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f688a08b54f4f02f0a3c382aefdb7884d3d69609f785bd253dc033243e3fe4" +dependencies = [ + "heck 0.4.1", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "target-lexicon" version = "0.12.15" @@ -3724,6 +3907,45 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "text-generation-benchmark" +version = "2.2.1-dev0" +dependencies = [ + "average", + "clap", + "crossterm", + "float-ord", + "hf-hub", + "ratatui", + "serde", + "serde_json", + "tabled", + "text-generation-client", + "thiserror", + "tokenizers", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "text-generation-client" +version = "2.2.1-dev0" +dependencies = [ + "async-trait", + "base64 0.22.1", + "futures", + "grpc-metadata", + "prost 0.12.6", + "prost-build", + "thiserror", + "tokio", + "tonic 0.10.2", + "tonic-build", + "tower", + "tracing", +] + [[package]] name = "text-generation-launcher" version = "2.2.1-dev0" @@ -3970,7 +4192,7 @@ dependencies = [ "backtrace", "bytes", "libc", - "mio", + "mio 1.0.1", "parking_lot", "pin-project-lite", "signal-hook-registry", diff --git a/Cargo.toml b/Cargo.toml index bf8a10f0..e2ea1142 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,15 @@ [workspace] members = [ - # "benchmark", + "benchmark", "backends/v3", - # "backends/client", "backends/grpc-metadata", "backends/trtllm", + "backends/client", "launcher" ] default-members = [ - # "benchmark", - # "backends/v3", - # "backends/client", + "benchmark", + "backends/v3", "backends/grpc-metadata", # "backends/trtllm", "launcher" diff --git a/Dockerfile b/Dockerfile index 52393a76..0d57e38d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN 
cargo chef prepare --recipe-path recipe.json
@@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo build --profile release-opt
diff --git a/Dockerfile_amd b/Dockerfile_amd
index 0aebeee5..51231638 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo chef prepare --recipe-path recipe.json
@@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo build --profile release-opt
diff --git a/Dockerfile_intel b/Dockerfile_intel
index 6a803a32..d20f0a01 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo chef prepare --recipe-path recipe.json
@@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo build --profile release-opt
diff --git a/backends/client/build.rs b/backends/client/build.rs
new file mode 100644
index 00000000..210cd603
--- /dev/null
+++ b/backends/client/build.rs
@@ -0,0 +1,35 @@
+use std::fs;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    println!("cargo:rerun-if-changed=../../proto/");
+
+    fs::create_dir_all("src/v2/pb").unwrap_or(());
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/v2/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
+        .map_err(|e| match e.kind() {
+            std::io::ErrorKind::NotFound => panic!("`protoc` not found, install libprotoc"),
+            std::io::ErrorKind::Other => panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases"),
+            e => e,
+        })
+        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
+    fs::create_dir_all("src/v3/pb").unwrap_or(());
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/v3/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
+        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
+    Ok(())
+}
diff --git a/backends/client/src/lib.rs b/backends/client/src/lib.rs
new file mode 100644
index 00000000..45bee10c
--- /dev/null
+++ b/backends/client/src/lib.rs
@@ -0,0 +1,91 @@
+//! Text Generation gRPC client library
+
+pub mod v2;
+pub mod v3;
+
+use async_trait::async_trait;
+use base64::{engine::general_purpose::STANDARD, Engine};
+use thiserror::Error;
+use tonic::transport;
+use tonic::Status;
+
+pub use v3::{Chunk, Image, Input, InputChunk};
+
+#[async_trait]
+pub trait Health {
+    /// Check if a generate server is healthy by asking it to allocate a tensor on device
+    async fn device_health(&self) -> Result<()>;
+
+    /// Check if a generate server is healthy by doing a forward pass.
+    /// EXPENSIVE
+    async fn model_health(&self) -> Result<()>;
+}
+
+#[derive(Debug)]
+pub struct ShardInfo {
+    pub requires_padding: bool,
+    pub dtype: String,
+    pub device_type: String,
+    pub window_size: Option<u32>,
+    pub speculate: u32,
+}
+
+#[derive(Error, Debug, Clone)]
+pub enum ClientError {
+    #[error("Could not connect to Text Generation server: {0}")]
+    Connection(String),
+    #[error("Server error: {0}")]
+    Generation(String),
+    #[error("Sharded results are empty")]
+    EmptyResults,
+}
+
+impl From<Status> for ClientError {
+    fn from(err: Status) -> Self {
+        let err = Self::Generation(err.message().to_string());
+        tracing::error!("{err}");
+        err
+    }
+}
+
+impl From<transport::Error> for ClientError {
+    fn from(err: transport::Error) -> Self {
+        let err = Self::Connection(err.to_string());
+        tracing::error!("{err}");
+        err
+    }
+}
+
+// Small convenience re-wrapping of `Chunk`.
+impl From<Chunk> for InputChunk {
+    fn from(chunk: Chunk) -> Self {
+        InputChunk { chunk: Some(chunk) }
+    }
+}
+
+/// Convert input chunks to a stringly-typed input, for backwards
+/// compatibility with backends that haven't implemented chunked inputs.
+pub trait ChunksToString {
+    /// Convert chunks to string.
+    fn chunks_to_string(&self) -> String;
+}
+
+impl ChunksToString for Vec<InputChunk> {
+    fn chunks_to_string(&self) -> String {
+        let mut output = String::new();
+        self.iter().for_each(|c| match &c.chunk {
+            Some(Chunk::Text(text)) => output.push_str(text),
+            Some(Chunk::Image(Image { data, mimetype })) => {
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
+            }
+            // We don't create empty chunks, so this should be unreachable.
+            None => unreachable!("Chunks should never be empty"),
+        });
+        output
+    }
+}
+
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+
+pub type Result<T> = std::result::Result<T, ClientError>;
diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
new file mode 100644
index 00000000..a996b14f
--- /dev/null
+++ b/backends/client/src/v3/client.rs
@@ -0,0 +1,283 @@
+use crate::v3::{pb, Chunk};
+use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
+/// Single shard Client
+use base64::engine::general_purpose::STANDARD;
+use base64::Engine;
+use grpc_metadata::InjectTelemetryContext;
+use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
+use pb::generate::v3::*;
+use std::cmp::min;
+use std::time::Duration;
+use tonic::transport::{Channel, Uri};
+use tracing::instrument;
+
+/// Text Generation Inference gRPC client
+#[derive(Debug, Clone)]
+pub struct Client {
+    stub: TextGenerationServiceClient<Channel>,
+}
+
+impl Client {
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let channel = Channel::builder(uri).connect().await?;
+
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
+    }
+
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
+        let channel = Channel::from_shared("http://[::]:50051".to_string())
+            .unwrap()
+            .connect_with_connector(tower::service_fn(move |_: Uri| {
+                tokio::net::UnixStream::connect(path.clone())
+            }))
+            .await?;
+
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
+    }
+
+    /// Returns a list of uris or unix sockets of all shards
+    #[instrument(skip(self))]
+    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
+        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
+        let response = self.stub.service_discovery(request).await.map_err(|_| {
+            ClientError::Connection("Server does not support v3 interface".to_string())
+        })?;
+        let urls = response
+            .into_inner()
+            .urls
+            .into_iter()
+            // Remove unix socket prefix
+            .map(|url| match url.strip_prefix("unix://") {
+                None => url,
+                Some(stripped_url) => stripped_url.to_string(),
+            })
+            .collect();
+        Ok(urls)
+    }
+
+    /// Get model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<InfoResponse> {
+        let request = tonic::Request::new(InfoRequest {}).inject_context();
+        let response = self.stub.info(request).await?.into_inner();
+        Ok(response)
+    }
+
+    /// Get model health
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let request = tonic::Request::new(HealthRequest {}).inject_context();
+        let response = self.stub.health(request).await?.into_inner();
+        Ok(response)
+    }
+
+    /// Clear the past generations cache
+    #[instrument(skip(self))]
+    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
+        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
+        self.stub.clear_cache(request).await?;
+        Ok(())
+    }
+
+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
+        let request = tonic::Request::new(FilterBatchRequest {
+            batch_id,
+            request_ids,
+        })
+        .inject_context();
+        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
+        Ok(filtered_batch.batch)
+    }
+
+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum number of tokens supported by the hardware
+    #[instrument(skip_all)]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+        max_batch_size: Option<usize>,
+    ) -> Result<Option<u32>> {
+        let mut n_tokens = 0;
+        let mut requests = Vec::new();
+        // Create requests
+        while n_tokens < max_prefill_tokens {
+            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+
+            let mut input_chunks = Vec::new();
+            input_chunks
+                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
+            if n_tokens == 0 {
+                input_chunks.push(
+                    Chunk::Image(Image {
+                        // Safe unwrap, because we control the data.
+                        data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
+                        mimetype: "image/jpeg;base64".to_string(),
+                    })
+                    .into(),
+                );
+            }
+
+            // Send stringly-typed inputs for compatibility with backends that haven't
+            // been updated to support chunks.
+
+            let mut inputs = String::new();
+            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            if n_tokens == 0 {
+                // One request is enough to test the vision heads.
+                // Sending images on other queries easily interferes with truncation.
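+                // The same warmup image is inlined as markdown below, so backends that
+                // only read the string input still receive it.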
+                inputs.push_str(&format!(
+                    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
+                ));
+            }
+
+            requests.push(Request {
+                id: 0,
+                inputs,
+                input_chunks: Some(Input {
+                    chunks: input_chunks,
+                }),
+                // We truncate the input on the server side to be sure that it has the correct size
+                truncate,
+                // Blocks and slots will be set on the server side if we use paged attention
+                blocks: vec![],
+                slots: vec![],
+                // Set sampling parameters to also take these ops into account in the max memory
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 0.9,
+                    top_k: 10,
+                    top_p: 0.9,
+                    typical_p: 0.9,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.2,
+                    frequency_penalty: 0.1,
+                    watermark: true,
+                    grammar: String::new(),
+                    grammar_type: GrammarType::None as i32,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: max_total_tokens - truncate,
+                    stop_sequences: vec![],
+                    ignore_eos_token: true,
+                }),
+                prefill_logprobs: true,
+                top_n_tokens: 20,
+                adapter_id: None,
+            });
+            n_tokens += max_input_length;
+
+            // Check max_batch_size
+            if Some(requests.len()) == max_batch_size {
+                break;
+            }
+        }
+
+        let batch = Batch {
+            id: 0,
+            size: requests.len() as u32,
+            requests,
+            max_tokens: max_input_length,
+            max_blocks: 0,
+        };
+
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_input_length,
+            max_prefill_tokens,
+            max_total_tokens,
+        })
+        .inject_context();
+        let response = self.stub.warmup(request).await?.into_inner();
+        Ok(response.max_supported_total_tokens)
+    }
+
+    /// Generate one token for each request in the given batch
+    ///
+    /// Returns Generation for each request in batch
+    /// and the next cached batch
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
+        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let response = self.stub.prefill(request).await?.into_inner();
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
+    }
+
+    /// Generate one token for each request in the given cached batches
+    ///
+    /// Returns Generation for each request in batches
+    /// and the next cached batch
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch| batch.size).sum::<u32>()))]
+    pub async fn decode(
+        &mut self,
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
+        let response = self.stub.decode(request).await?.into_inner();
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
+    }
+}
+
+pub struct PrefillTimings {
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl PrefillTimings {
+    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
+
+pub struct DecodeTimings {
+    pub concat: Option<Duration>,
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl DecodeTimings {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            concat: concat_ns.map(Duration::from_nanos),
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
diff --git a/backends/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs
new file mode 100644
index 00000000..4a1296a2
--- /dev/null
+++ b/backends/client/src/v3/mod.rs
@@ -0,0 +1,13 @@
+#[allow(clippy::derive_partial_eq_without_eq)]
+mod pb;
+
+mod client;
+mod sharded_client;
+
+pub use client::Client;
+pub use pb::generate::v3::{
+    input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
+    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
+    StoppingCriteriaParameters, Tokens,
+};
+pub use sharded_client::ShardedClient;
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
new file mode 100644
index 00000000..ae8a899b
--- /dev/null
+++ b/backends/client/src/v3/sharded_client.rs
@@ -0,0 +1,259 @@
+/// Multi shard Client
+use crate::{v3, Health, ShardInfo};
+use crate::{ClientError, Result};
+
+use crate::v3::{Chunk, InfoResponse, Input};
+use async_trait::async_trait;
+use futures::future::join_all;
+use tonic::transport::Uri;
+use tracing::instrument;
+use v3::client::{DecodeTimings, PrefillTimings};
+use v3::{
+    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
+
+#[derive(Debug, Clone)]
+/// Text Generation Inference gRPC multi client
+pub struct ShardedClient {
+    clients: Vec<Client>,
+}
+
+impl ShardedClient {
+    fn new(clients: Vec<Client>) -> Self {
+        Self { clients }
+    }
+
+    /// Create a new ShardedClient from a master client. The master client will communicate with
+    /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
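+    /// Each shard returned by discovery is then dialed over its unix socket.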
+    async fn from_master_client(mut master_client: Client) -> Result<Self> {
+        // Get all uris/unix sockets from the master client
+        let uris = master_client.service_discovery().await?;
+        let futures = uris.into_iter().map(Client::connect_uds);
+        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
+        Ok(Self::new(clients?))
+    }
+
+    /// Returns a client connected to the given uri
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let master_client = Client::connect(uri).await?;
+        Self::from_master_client(master_client).await
+    }
+
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
+        let master_client = Client::connect_uds(path).await?;
+        Self::from_master_client(master_client).await
+    }
+
+    /// Get the model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<ShardInfo> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.info())
+            .collect();
+        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
+    }
+
+    /// gRPC health check
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.health())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
+    /// Clear the past generations cache
+    #[instrument(skip(self))]
+    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache(batch_id))
+            .collect();
+        join_all(futures).await.into_iter().collect()
+    }
+
+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
+            .collect();
+        // All shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum number of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+        max_batch_size: Option<usize>,
+    ) -> Result<Option<u32>> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| {
+                Box::pin(client.warmup(
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                    max_batch_size,
+                ))
+            })
+            .collect();
+        // Take the minimum value
+        let results = join_all(futures)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<Option<u32>>>>()?;
+        Ok(results.into_iter().flatten().min())
+    }
+
+    /// Generate one token for each request in the given batch
+    ///
+    /// Returns Generation for each request in batch
+    /// and the next cached batch
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .collect();
+        #[allow(clippy::type_complexity)]
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
+            join_all(futures).await.into_iter().collect();
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
+    }
+
+    /// Generate one token for each request in the given cached batches
+    ///
+    /// Returns Generation for each request in batches
+    /// and the next cached batch
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch| batch.size).sum::<u32>()))]
+    pub async fn decode(
+        &mut self,
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.decode(batches.clone())))
+            .collect();
+        #[allow(clippy::type_complexity)]
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
+            join_all(futures).await.into_iter().collect();
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
+    }
+}
+
+impl From<InfoResponse> for ShardInfo {
+    fn from(value: InfoResponse) -> Self {
+        Self {
+            requires_padding: value.requires_padding,
+            dtype: value.dtype,
+            device_type: value.device_type,
+            window_size: value.window_size,
+            speculate: value.speculate,
+        }
+    }
+}
+
+#[async_trait]
+impl Health for ShardedClient {
+    async fn device_health(&self) -> Result<()> {
+        self.clone().health().await?;
+        Ok(())
+    }
+
+    async fn model_health(&self) -> Result<()> {
+        // Dummy batch of 1 token and 1 generated token
+        let liveness_request = Request {
+            id: u64::MAX,
+            inputs: "liveness".to_string(),
+            input_chunks: Some(Input {
+                chunks: vec![Chunk::Text("liveness".into()).into()],
+            }),
+            truncate: 10,
+            prefill_logprobs: false,
+            parameters: Some(NextTokenChooserParameters {
+                temperature: 1.0,
+                top_k: 0,
+                top_p: 1.0,
+                typical_p: 1.0,
+                do_sample: false,
+                seed: 0,
+                repetition_penalty: 1.0,
+                frequency_penalty: 0.0,
+                watermark: false,
+                grammar: String::new(),
+                grammar_type: GrammarType::None as i32,
+            }),
+            stopping_parameters: Some(StoppingCriteriaParameters {
+                max_new_tokens: 1,
+                stop_sequences: vec![],
+                ignore_eos_token: false,
+            }),
+            top_n_tokens: 0,
+            // Block 0 is reserved for health checks
+            blocks: vec![0],
+            slots: (0..16).collect(),
+            adapter_id: None,
+        };
+        let batch = Batch {
+            id: u64::MAX,
+            requests: vec![liveness_request],
+            size: 1,
+            max_tokens: 2,
+            max_blocks: 1,
+        };
+        self.clone().prefill(batch).await?;
+        Ok(())
+    }
+}
diff --git a/router/client/src/v2/pb/generate.v2.rs b/router/client/src/v2/pb/generate.v2.rs
deleted file mode 100644
index 1a206360..00000000
--- a/router/client/src/v2/pb/generate.v2.rs
+++ /dev/null
@@ -1,647 +0,0 @@
-// This file is @generated by prost-build.
-#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HealthRequest {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HealthResponse {} -/// / Empty request -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InfoRequest {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InfoResponse { - #[prost(bool, tag = "1")] - pub requires_padding: bool, - #[prost(string, tag = "2")] - pub dtype: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub device_type: ::prost::alloc::string::String, - #[prost(uint32, optional, tag = "4")] - pub window_size: ::core::option::Option, - #[prost(uint32, tag = "5")] - pub speculate: u32, -} -/// / Empty request -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ServiceDiscoveryRequest {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ServiceDiscoveryResponse { - /// / Other shards urls - #[prost(string, repeated, tag = "1")] - pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ClearCacheRequest { - /// / Optional batch id - #[prost(uint64, optional, tag = "1")] - pub id: ::core::option::Option, -} -/// / Empty response -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ClearCacheResponse {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct NextTokenChooserParameters { - /// / exponential scaling output probability distribution - #[prost(float, tag = "1")] - pub temperature: f32, - /// / restricting to the k highest probability elements - #[prost(uint32, tag = "2")] - pub top_k: u32, - /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off - #[prost(float, tag = "3")] - pub top_p: f32, - /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off - #[prost(float, tag = "4")] - pub typical_p: f32, - /// / apply sampling on the logits - #[prost(bool, tag = "5")] - pub do_sample: bool, - /// / random seed for sampling - #[prost(uint64, tag = "6")] - pub seed: u64, - /// / repetition penalty - #[prost(float, tag = "7")] - pub repetition_penalty: f32, - /// / frequency penalty - #[prost(float, tag = "9")] - pub frequency_penalty: f32, - /// / token watermarking using "A Watermark for Large Language Models" - #[prost(bool, tag = "8")] - pub watermark: bool, - /// / grammar (applied if not empty) - #[prost(string, tag = "10")] - pub grammar: ::prost::alloc::string::String, - /// / grammar type - #[prost(enumeration = "GrammarType", tag = "11")] - pub grammar_type: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct StoppingCriteriaParameters { - /// / Maximum number of generated tokens - #[prost(uint32, tag = "1")] - pub max_new_tokens: u32, - /// / Optional stopping sequences - #[prost(string, repeated, tag = "2")] - pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// / Ignore end of sequence token - /// / used for benchmarking - #[prost(bool, tag = "3")] - pub ignore_eos_token: bool, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, 
::prost::Message)] -pub struct Request { - /// / Request ID - #[prost(uint64, tag = "1")] - pub id: u64, - /// / The generation context - #[prost(string, tag = "2")] - pub inputs: ::prost::alloc::string::String, - /// / Context truncation - #[prost(uint32, tag = "3")] - pub truncate: u32, - /// / Next Token Chooser Parameters - #[prost(message, optional, tag = "4")] - pub parameters: ::core::option::Option, - /// / Stopping Criteria Parameters - #[prost(message, optional, tag = "5")] - pub stopping_parameters: ::core::option::Option, - /// / Return prefill logprobs - #[prost(bool, tag = "6")] - pub prefill_logprobs: bool, - /// / Return most likely n tokens - #[prost(uint32, tag = "7")] - pub top_n_tokens: u32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Batch { - /// / Batch ID - #[prost(uint64, tag = "1")] - pub id: u64, - /// / Individual requests - #[prost(message, repeated, tag = "2")] - pub requests: ::prost::alloc::vec::Vec, - /// / Batch size (==len(requests)) - #[prost(uint32, tag = "3")] - pub size: u32, - /// / Maximum number of tokens this batch will grow to - #[prost(uint32, tag = "4")] - pub max_tokens: u32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CachedBatch { - /// / Batch ID - #[prost(uint64, tag = "1")] - pub id: u64, - /// / Individual requests ids - #[prost(uint64, repeated, tag = "2")] - pub request_ids: ::prost::alloc::vec::Vec, - /// / Batch size (==len(requests)) - #[prost(uint32, tag = "3")] - pub size: u32, - /// / Maximum number of tokens this batch will grow to - #[prost(uint32, tag = "4")] - pub max_tokens: u32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct GeneratedText { - /// / Output - #[prost(string, tag = "1")] - pub text: ::prost::alloc::string::String, - /// / Number of generated tokens - #[prost(uint32, tag = "2")] - pub generated_tokens: u32, - /// / Finish reason - #[prost(enumeration = "FinishReason", tag = "3")] - pub finish_reason: i32, - /// / Seed - #[prost(uint64, optional, tag = "4")] - pub seed: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Tokens { - /// / Token IDs - #[prost(uint32, repeated, tag = "1")] - pub ids: ::prost::alloc::vec::Vec, - /// / Logprobs - #[prost(float, repeated, tag = "2")] - pub logprobs: ::prost::alloc::vec::Vec, - /// / tokens - #[prost(string, repeated, tag = "3")] - pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// / special - #[prost(bool, repeated, tag = "4")] - pub is_special: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Generation { - /// / Request ID - #[prost(uint64, tag = "1")] - pub request_id: u64, - /// / Prefill tokens (optional) - #[prost(message, optional, tag = "2")] - pub prefill_tokens: ::core::option::Option, - #[prost(message, optional, tag = "3")] - pub tokens: ::core::option::Option, - /// / Complete generated text - #[prost(message, optional, tag = "4")] - pub generated_text: ::core::option::Option, - /// / Top tokens - #[prost(message, repeated, tag = "5")] - pub top_tokens: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FilterBatchRequest { - /// / Batch ID - #[prost(uint64, tag = "1")] - pub batch_id: u64, 
- /// / Requests to keep - #[prost(uint64, repeated, tag = "2")] - pub request_ids: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FilterBatchResponse { - /// / Filtered Batch (cached) - #[prost(message, optional, tag = "1")] - pub batch: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PrefillRequest { - /// / Batch - #[prost(message, optional, tag = "1")] - pub batch: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PrefillResponse { - /// / Generation - #[prost(message, repeated, tag = "1")] - pub generations: ::prost::alloc::vec::Vec, - /// / Next batch (cached) - #[prost(message, optional, tag = "2")] - pub batch: ::core::option::Option, - /// / Forward elapsed time in nanoseconds - #[prost(uint64, tag = "3")] - pub forward_ns: u64, - /// / Decode elapsed time in nanoseconds - #[prost(uint64, tag = "4")] - pub decode_ns: u64, - /// / Total elapsed time in nanoseconds - #[prost(uint64, tag = "5")] - pub total_ns: u64, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DecodeRequest { - /// / Cached batches - #[prost(message, repeated, tag = "1")] - pub batches: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DecodeResponse { - /// / Decodes - #[prost(message, repeated, tag = "1")] - pub generations: ::prost::alloc::vec::Vec, - /// / Next batch (cached) - #[prost(message, optional, tag = "2")] - pub batch: ::core::option::Option, - /// / Forward elapsed time in nanoseconds - #[prost(uint64, tag = "3")] - pub forward_ns: u64, - /// / Decode elapsed time in nanoseconds - #[prost(uint64, tag = "4")] - pub decode_ns: u64, - /// / Total elapsed time in nanoseconds - #[prost(uint64, tag = "5")] - pub total_ns: u64, - /// / Concatenate elapsed time in nanoseconds - #[prost(uint64, optional, tag = "6")] - pub concat_ns: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct WarmupRequest { - /// / Batch to warmup on - #[prost(message, optional, tag = "1")] - pub batch: ::core::option::Option, - #[prost(uint32, tag = "2")] - pub max_input_length: u32, - #[prost(uint32, tag = "3")] - pub max_prefill_tokens: u32, - #[prost(uint32, tag = "4")] - pub max_total_tokens: u32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct WarmupResponse { - /// / Maximum number of tokens supported by the model - #[prost(uint32, optional, tag = "1")] - pub max_supported_total_tokens: ::core::option::Option, -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum GrammarType { - None = 0, - Json = 1, - Regex = 2, -} -impl GrammarType { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
- pub fn as_str_name(&self) -> &'static str { - match self { - GrammarType::None => "GRAMMAR_TYPE_NONE", - GrammarType::Json => "GRAMMAR_TYPE_JSON", - GrammarType::Regex => "GRAMMAR_TYPE_REGEX", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "GRAMMAR_TYPE_NONE" => Some(Self::None), - "GRAMMAR_TYPE_JSON" => Some(Self::Json), - "GRAMMAR_TYPE_REGEX" => Some(Self::Regex), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum FinishReason { - Length = 0, - EosToken = 1, - StopSequence = 2, -} -impl FinishReason { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - FinishReason::Length => "FINISH_REASON_LENGTH", - FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN", - FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "FINISH_REASON_LENGTH" => Some(Self::Length), - "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken), - "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence), - _ => None, - } - } -} -/// Generated client implementations. -pub mod text_generation_service_client { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; - use tonic::codegen::http::Uri; - #[derive(Debug, Clone)] - pub struct TextGenerationServiceClient { - inner: tonic::client::Grpc, - } - impl TextGenerationServiceClient { - /// Attempt to create a new client by connecting to a given endpoint. - pub async fn connect(dst: D) -> Result - where - D: TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) - } - } - impl TextGenerationServiceClient - where - T: tonic::client::GrpcService, - T::Error: Into, - T::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, - { - pub fn new(inner: T) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_origin(inner: T, origin: Uri) -> Self { - let inner = tonic::client::Grpc::with_origin(inner, origin); - Self { inner } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> TextGenerationServiceClient> - where - F: tonic::service::Interceptor, - T::ResponseBody: Default, - T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, - >, - , - >>::Error: Into + Send + Sync, - { - TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor)) - } - /// Compress requests with the given encoding. - /// - /// This requires the server to support it otherwise it might respond with an - /// error. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.send_compressed(encoding); - self - } - /// Enable decompressing responses. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.accept_compressed(encoding); - self - } - /// Limits the maximum size of a decoded message. 
- /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_decoding_message_size(limit); - self - } - /// Limits the maximum size of an encoded message. - /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_encoding_message_size(limit); - self - } - /// / Model Info - pub async fn info( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/Info", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info")); - self.inner.unary(req, path, codec).await - } - /// / Service discovery - pub async fn service_discovery( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/ServiceDiscovery", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "generate.v2.TextGenerationService", - "ServiceDiscovery", - ), - ); - self.inner.unary(req, path, codec).await - } - /// / Empties batch cache - pub async fn clear_cache( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/ClearCache", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"), - ); - self.inner.unary(req, path, codec).await - } - /// / Remove requests from a cached batch - pub async fn filter_batch( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/FilterBatch", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"), - ); - self.inner.unary(req, path, codec).await - } - /// / Warmup the model and compute max cache size - pub async fn warmup( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = 
http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/Warmup", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup")); - self.inner.unary(req, path, codec).await - } - /// / Prefill batch and decode first token - pub async fn prefill( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/Prefill", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill")); - self.inner.unary(req, path, codec).await - } - /// / Decode token for a list of prefilled batches - pub async fn decode( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/Decode", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode")); - self.inner.unary(req, path, codec).await - } - /// / Health check - pub async fn health( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/generate.v2.TextGenerationService/Health", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health")); - self.inner.unary(req, path, codec).await - } - } -} diff --git a/router/client/src/v2/pb/mod.rs b/router/client/src/v2/pb/mod.rs deleted file mode 100644 index 095ead1f..00000000 --- a/router/client/src/v2/pb/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -// This file is @generated by prost-build. -pub mod generate { - pub mod v2 { - include!("generate.v2.rs"); - } -} diff --git a/router/client/src/v3/pb/generate.v3.rs b/router/client/src/v3/pb/generate.v3.rs deleted file mode 100644 index 72315ea3..00000000 --- a/router/client/src/v3/pb/generate.v3.rs +++ /dev/null @@ -1,697 +0,0 @@ -// This file is @generated by prost-build. 
-#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HealthRequest {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HealthResponse {} -/// / Empty request -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InfoRequest {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InfoResponse { - #[prost(bool, tag = "1")] - pub requires_padding: bool, - #[prost(string, tag = "2")] - pub dtype: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub device_type: ::prost::alloc::string::String, - #[prost(uint32, optional, tag = "4")] - pub window_size: ::core::option::Option, - #[prost(uint32, tag = "5")] - pub speculate: u32, -} -/// / Empty request -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ServiceDiscoveryRequest {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ServiceDiscoveryResponse { - /// / Other shards urls - #[prost(string, repeated, tag = "1")] - pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ClearCacheRequest { - /// / Optional batch id - #[prost(uint64, optional, tag = "1")] - pub id: ::core::option::Option, -} -/// / Empty response -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ClearCacheResponse {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Image { - /// / Binary image data. - #[prost(bytes = "vec", tag = "1")] - pub data: ::prost::alloc::vec::Vec, - /// / Image MIME type. - #[prost(string, tag = "2")] - pub mimetype: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InputChunk { - #[prost(oneof = "input_chunk::Chunk", tags = "1, 2")] - pub chunk: ::core::option::Option, -} -/// Nested message and enum types in `InputChunk`. 
-pub mod input_chunk { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Chunk { - /// / Plain text data - #[prost(string, tag = "1")] - Text(::prost::alloc::string::String), - /// / Image data - #[prost(message, tag = "2")] - Image(super::Image), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Input { - #[prost(message, repeated, tag = "1")] - pub chunks: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct NextTokenChooserParameters { - /// / exponential scaling output probability distribution - #[prost(float, tag = "1")] - pub temperature: f32, - /// / restricting to the k highest probability elements - #[prost(uint32, tag = "2")] - pub top_k: u32, - /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off - #[prost(float, tag = "3")] - pub top_p: f32, - /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off - #[prost(float, tag = "4")] - pub typical_p: f32, - /// / apply sampling on the logits - #[prost(bool, tag = "5")] - pub do_sample: bool, - /// / random seed for sampling - #[prost(uint64, tag = "6")] - pub seed: u64, - /// / repetition penalty - #[prost(float, tag = "7")] - pub repetition_penalty: f32, - /// / frequency penalty - #[prost(float, tag = "9")] - pub frequency_penalty: f32, - /// / token watermarking using "A Watermark for Large Language Models" - #[prost(bool, tag = "8")] - pub watermark: bool, - /// / grammar (applied if not empty) - #[prost(string, tag = "10")] - pub grammar: ::prost::alloc::string::String, - /// / grammar type - #[prost(enumeration = "GrammarType", tag = "11")] - pub grammar_type: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct StoppingCriteriaParameters { - /// / Maximum number of generated tokens - #[prost(uint32, tag = "1")] - pub max_new_tokens: u32, - /// / Optional stopping sequences - #[prost(string, repeated, tag = "2")] - pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// / Ignore end of sequence token - /// / used for benchmarking - #[prost(bool, tag = "3")] - pub ignore_eos_token: bool, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Request { - /// / Request ID - #[prost(uint64, tag = "1")] - pub id: u64, - /// / The generation context as chunks - #[prost(message, optional, tag = "8")] - pub input_chunks: ::core::option::Option, - /// / The generation context, stringified input_chunks - #[prost(string, tag = "2")] - pub inputs: ::prost::alloc::string::String, - /// / Context truncation - #[prost(uint32, tag = "3")] - pub truncate: u32, - /// / Next Token Chooser Parameters - #[prost(message, optional, tag = "4")] - pub parameters: ::core::option::Option, - /// / Stopping Criteria Parameters - #[prost(message, optional, tag = "5")] - pub stopping_parameters: ::core::option::Option, - /// / Return prefill logprobs - #[prost(bool, tag = "6")] - pub prefill_logprobs: bool, - /// / Return most likely n tokens - #[prost(uint32, tag = "7")] - pub top_n_tokens: u32, - /// / Paged attention blocks - #[prost(uint32, repeated, tag = "9")] - pub blocks: ::prost::alloc::vec::Vec, - /// / Paged attention slots - #[prost(uint32, repeated, tag = "10")] - pub slots: ::prost::alloc::vec::Vec, - /// / LORA adapter index - #[prost(string, optional, tag = "11")] - 
-    pub adapter_id: ::core::option::Option<::prost::alloc::string::String>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Batch {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / Individual requests
-    #[prost(message, repeated, tag = "2")]
-    pub requests: ::prost::alloc::vec::Vec<Request>,
-    /// / Batch size (==len(requests))
-    #[prost(uint32, tag = "3")]
-    pub size: u32,
-    /// / Maximum number of tokens this batch will grow to
-    #[prost(uint32, tag = "4")]
-    pub max_tokens: u32,
-    /// / Maximum number of Paged Attention blocks
-    #[prost(uint32, tag = "5")]
-    pub max_blocks: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct CachedBatch {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / Individual requests ids
-    #[prost(uint64, repeated, tag = "2")]
-    pub request_ids: ::prost::alloc::vec::Vec<u64>,
-    /// / Batch size (==len(requests))
-    #[prost(uint32, tag = "3")]
-    pub size: u32,
-    /// / Maximum number of tokens this batch will grow to
-    #[prost(uint32, tag = "4")]
-    pub max_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct GeneratedText {
-    /// / Output
-    #[prost(string, tag = "1")]
-    pub text: ::prost::alloc::string::String,
-    /// / Number of generated tokens
-    #[prost(uint32, tag = "2")]
-    pub generated_tokens: u32,
-    /// / Finish reason
-    #[prost(enumeration = "FinishReason", tag = "3")]
-    pub finish_reason: i32,
-    /// / Seed
-    #[prost(uint64, optional, tag = "4")]
-    pub seed: ::core::option::Option<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Tokens {
-    /// / Token IDs
-    #[prost(uint32, repeated, tag = "1")]
-    pub ids: ::prost::alloc::vec::Vec<u32>,
-    /// / Logprobs
-    #[prost(float, repeated, tag = "2")]
-    pub logprobs: ::prost::alloc::vec::Vec<f32>,
-    /// / tokens
-    #[prost(string, repeated, tag = "3")]
-    pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-    /// / special
-    #[prost(bool, repeated, tag = "4")]
-    pub is_special: ::prost::alloc::vec::Vec<bool>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Generation {
-    /// / Request ID
-    #[prost(uint64, tag = "1")]
-    pub request_id: u64,
-    /// / Prefill tokens (optional)
-    #[prost(message, optional, tag = "2")]
-    pub prefill_tokens: ::core::option::Option<Tokens>,
-    #[prost(message, optional, tag = "3")]
-    pub tokens: ::core::option::Option<Tokens>,
-    /// / Complete generated text
-    #[prost(message, optional, tag = "4")]
-    pub generated_text: ::core::option::Option<GeneratedText>,
-    /// / Top tokens
-    #[prost(message, repeated, tag = "5")]
-    pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct FilterBatchRequest {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub batch_id: u64,
-    /// / Requests to keep
-    #[prost(uint64, repeated, tag = "2")]
-    pub request_ids: ::prost::alloc::vec::Vec<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct FilterBatchResponse {
-    /// / Filtered Batch (cached)
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<CachedBatch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PrefillRequest {
-    /// / Batch
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<Batch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PrefillResponse {
-    /// / Generation
-    #[prost(message, repeated, tag = "1")]
-    pub generations: ::prost::alloc::vec::Vec<Generation>,
-    /// / Next batch (cached)
-    #[prost(message, optional, tag = "2")]
-    pub batch: ::core::option::Option<CachedBatch>,
-    /// / Forward elapsed time in nanoseconds
-    #[prost(uint64, tag = "3")]
-    pub forward_ns: u64,
-    /// / Decode elapsed time in nanoseconds
-    #[prost(uint64, tag = "4")]
-    pub decode_ns: u64,
-    /// / Total elapsed time in nanoseconds
-    #[prost(uint64, tag = "5")]
-    pub total_ns: u64,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DecodeRequest {
-    /// / Cached batches
-    #[prost(message, repeated, tag = "1")]
-    pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DecodeResponse {
-    /// / Decodes
-    #[prost(message, repeated, tag = "1")]
-    pub generations: ::prost::alloc::vec::Vec<Generation>,
-    /// / Next batch (cached)
-    #[prost(message, optional, tag = "2")]
-    pub batch: ::core::option::Option<CachedBatch>,
-    /// / Forward elapsed time in nanoseconds
-    #[prost(uint64, tag = "3")]
-    pub forward_ns: u64,
-    /// / Decode elapsed time in nanoseconds
-    #[prost(uint64, tag = "4")]
-    pub decode_ns: u64,
-    /// / Total elapsed time in nanoseconds
-    #[prost(uint64, tag = "5")]
-    pub total_ns: u64,
-    /// / Concatenate elapsed time in nanoseconds
-    #[prost(uint64, optional, tag = "6")]
-    pub concat_ns: ::core::option::Option<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct WarmupRequest {
-    /// / Batch to warmup on
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<Batch>,
-    #[prost(uint32, tag = "2")]
-    pub max_input_length: u32,
-    #[prost(uint32, tag = "3")]
-    pub max_prefill_tokens: u32,
-    #[prost(uint32, tag = "4")]
-    pub max_total_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct WarmupResponse {
-    /// / Maximum number of tokens supported by the model
-    #[prost(uint32, optional, tag = "1")]
-    pub max_supported_total_tokens: ::core::option::Option<u32>,
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
-pub enum GrammarType {
-    None = 0,
-    Json = 1,
-    Regex = 2,
-}
-impl GrammarType {
-    /// String value of the enum field names used in the ProtoBuf definition.
-    ///
-    /// The values are not transformed in any way and thus are considered stable
-    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
-    pub fn as_str_name(&self) -> &'static str {
-        match self {
-            GrammarType::None => "GRAMMAR_TYPE_NONE",
-            GrammarType::Json => "GRAMMAR_TYPE_JSON",
-            GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
-        }
-    }
-    /// Creates an enum from field names used in the ProtoBuf definition.
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "GRAMMAR_TYPE_NONE" => Some(Self::None),
-            "GRAMMAR_TYPE_JSON" => Some(Self::Json),
-            "GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
-            _ => None,
-        }
-    }
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
-pub enum FinishReason {
-    Length = 0,
-    EosToken = 1,
-    StopSequence = 2,
-}
-impl FinishReason {
-    /// String value of the enum field names used in the ProtoBuf definition.
-    ///
-    /// The values are not transformed in any way and thus are considered stable
-    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
-    pub fn as_str_name(&self) -> &'static str {
-        match self {
-            FinishReason::Length => "FINISH_REASON_LENGTH",
-            FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
-            FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
-        }
-    }
-    /// Creates an enum from field names used in the ProtoBuf definition.
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "FINISH_REASON_LENGTH" => Some(Self::Length),
-            "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
-            "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
-            _ => None,
-        }
-    }
-}
-/// Generated client implementations.
-pub mod text_generation_service_client {
-    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
-    use tonic::codegen::*;
-    use tonic::codegen::http::Uri;
-    #[derive(Debug, Clone)]
-    pub struct TextGenerationServiceClient<T> {
-        inner: tonic::client::Grpc<T>,
-    }
-    impl TextGenerationServiceClient<tonic::transport::Channel> {
-        /// Attempt to create a new client by connecting to a given endpoint.
-        pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
-        where
-            D: TryInto<tonic::transport::Endpoint>,
-            D::Error: Into<StdError>,
-        {
-            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
-            Ok(Self::new(conn))
-        }
-    }
-    impl<T> TextGenerationServiceClient<T>
-    where
-        T: tonic::client::GrpcService<tonic::body::BoxBody>,
-        T::Error: Into<StdError>,
-        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
-        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
-    {
-        pub fn new(inner: T) -> Self {
-            let inner = tonic::client::Grpc::new(inner);
-            Self { inner }
-        }
-        pub fn with_origin(inner: T, origin: Uri) -> Self {
-            let inner = tonic::client::Grpc::with_origin(inner, origin);
-            Self { inner }
-        }
-        pub fn with_interceptor<F>(
-            inner: T,
-            interceptor: F,
-        ) -> TextGenerationServiceClient<InterceptedService<T, F>>
-        where
-            F: tonic::service::Interceptor,
-            T::ResponseBody: Default,
-            T: tonic::codegen::Service<
-                http::Request<tonic::body::BoxBody>,
-                Response = http::Response<
-                    <T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
-                >,
-            >,
-            <T as tonic::codegen::Service<
-                http::Request<tonic::body::BoxBody>,
-            >>::Error: Into<StdError> + Send + Sync,
-        {
-            TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
-        }
-        /// Compress requests with the given encoding.
-        ///
-        /// This requires the server to support it otherwise it might respond with an
-        /// error.
-        #[must_use]
-        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
-            self.inner = self.inner.send_compressed(encoding);
-            self
-        }
-        /// Enable decompressing responses.
-        #[must_use]
-        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
-            self.inner = self.inner.accept_compressed(encoding);
-            self
-        }
-        /// Limits the maximum size of a decoded message.
-        ///
-        /// Default: `4MB`
-        #[must_use]
-        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
-            self.inner = self.inner.max_decoding_message_size(limit);
-            self
-        }
-        /// Limits the maximum size of an encoded message.
-        ///
-        /// Default: `usize::MAX`
-        #[must_use]
-        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
-            self.inner = self.inner.max_encoding_message_size(limit);
-            self
-        }
-        /// / Model Info
-        pub async fn info(
-            &mut self,
-            request: impl tonic::IntoRequest<super::InfoRequest>,
-        ) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/Info",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v3.TextGenerationService", "Info"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Service discovery
-        pub async fn service_discovery(
-            &mut self,
-            request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::ServiceDiscoveryResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/ServiceDiscovery",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new(
-                        "generate.v3.TextGenerationService",
-                        "ServiceDiscovery",
-                    ),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Empties batch cache
-        pub async fn clear_cache(
-            &mut self,
-            request: impl tonic::IntoRequest<super::ClearCacheRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::ClearCacheResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/ClearCache",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("generate.v3.TextGenerationService", "ClearCache"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Remove requests from a cached batch
-        pub async fn filter_batch(
-            &mut self,
-            request: impl tonic::IntoRequest<super::FilterBatchRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::FilterBatchResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/FilterBatch",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("generate.v3.TextGenerationService", "FilterBatch"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Warmup the model and compute max cache size
-        pub async fn warmup(
-            &mut self,
-            request: impl tonic::IntoRequest<super::WarmupRequest>,
-        ) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/Warmup",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v3.TextGenerationService", "Warmup"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Prefill batch and decode first token
-        pub async fn prefill(
-            &mut self,
-            request: impl tonic::IntoRequest<super::PrefillRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::PrefillResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/Prefill",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v3.TextGenerationService", "Prefill"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Decode token for a list of prefilled batches
-        pub async fn decode(
-            &mut self,
-            request: impl tonic::IntoRequest<super::DecodeRequest>,
-        ) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/Decode",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v3.TextGenerationService", "Decode"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Health check
-        pub async fn health(
-            &mut self,
-            request: impl tonic::IntoRequest<super::HealthRequest>,
-        ) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v3.TextGenerationService/Health",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v3.TextGenerationService", "Health"));
-            self.inner.unary(req, path, codec).await
-        }
-    }
-}
diff --git a/router/client/src/v3/pb/mod.rs b/router/client/src/v3/pb/mod.rs
deleted file mode 100644
index b5397d05..00000000
--- a/router/client/src/v3/pb/mod.rs
+++ /dev/null
@@ -1,6 +0,0 @@
-// This file is @generated by prost-build.
-pub mod generate {
-    pub mod v3 {
-        include!("generate.v3.rs");
-    }
-}