From 4dbb342fe30d875871ac379768bdc8d26769aca9 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 3 Jun 2024 13:30:31 +0200
Subject: [PATCH] small refactor to make router a bit more agnostic

---
 proto/v3/generate.proto            | 236 +++++++++++++++++++++++++++++
 router/client/Cargo.toml           |   1 +
 router/client/build.rs             |  18 ++-
 router/client/src/lib.rs           |  37 +++--
 router/client/src/pb/.gitignore    |   1 -
 router/client/src/v2/mod.rs        |  13 ++
 router/client/src/v2/pb/.gitignore |   1 +
 router/client/src/v3/mod.rs        |  13 ++
 router/client/src/v3/pb/.gitignore |   1 +
 router/src/health.rs               |  75 +++------
 router/src/infer.rs                |  67 +++++---
 router/src/lib.rs                  |   2 +-
 router/src/main.rs                 |   2 +-
 router/src/queue.rs                |  52 ++++++-
 router/src/server.rs               |  23 +--
 router/src/validation.rs           |  75 ++++++---
 16 files changed, 478 insertions(+), 139 deletions(-)
 create mode 100644 proto/v3/generate.proto
 delete mode 100644 router/client/src/pb/.gitignore
 create mode 100644 router/client/src/v2/mod.rs
 create mode 100644 router/client/src/v2/pb/.gitignore
 create mode 100644 router/client/src/v3/mod.rs
 create mode 100644 router/client/src/v3/pb/.gitignore

diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
new file mode 100644
index 00000000..e594c607
--- /dev/null
+++ b/proto/v3/generate.proto
@@ -0,0 +1,236 @@
+syntax = "proto3";
+
+package generate.v3;
+
+service TextGenerationService {
+    /// Model Info
+    rpc Info (InfoRequest) returns (InfoResponse) {}
+    /// Service discovery
+    rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
+    /// Empties batch cache
+    rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+    /// Remove requests from a cached batch
+    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+    /// Warmup the model and compute max cache size
+    rpc Warmup (WarmupRequest) returns (WarmupResponse);
+    /// Prefill batch and decode first token
+    rpc Prefill (PrefillRequest) returns (PrefillResponse);
+    /// Decode token for a list of prefilled batches
+    rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
+}
+
+message HealthRequest {}
+message HealthResponse {}
+
+/// Empty request
+message InfoRequest {}
+
+message InfoResponse {
+    bool requires_padding = 1;
+    string dtype = 2;
+    string device_type = 3;
+    optional uint32 window_size = 4;
+    uint32 speculate = 5;
+}
+
+/// Empty request
+message ServiceDiscoveryRequest {}
+
+message ServiceDiscoveryResponse {
+    /// Other shards urls
+    repeated string urls = 1;
+}
+
+message ClearCacheRequest {
+    /// Optional batch id
+    optional uint64 id = 1;
+}
+
+/// Empty response
+message ClearCacheResponse {}
+
+enum GrammarType {
+    GRAMMAR_TYPE_NONE = 0;
+    GRAMMAR_TYPE_JSON = 1;
+    GRAMMAR_TYPE_REGEX = 2;
+}
+
+message NextTokenChooserParameters {
+    /// exponential scaling output probability distribution
+    float temperature = 1;
+    /// restricting to the k highest probability elements
+    uint32 top_k = 2;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float top_p = 3;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float typical_p = 4;
+    /// apply sampling on the logits
+    bool do_sample = 5;
+    /// random seed for sampling
+    uint64 seed = 6;
+    /// repetition penalty
+    float repetition_penalty = 7;
+    /// frequency penalty
+    float frequency_penalty = 9;
+    /// token watermarking using "A Watermark for Large Language Models"
+    bool watermark = 8;
+    /// grammar (applied if not empty)
+    string grammar = 10;
+    /// grammar type
+    GrammarType grammar_type = 11;
+}
+
+message StoppingCriteriaParameters {
+    /// Maximum number of generated tokens
+    uint32 max_new_tokens = 1;
+    /// Optional stopping sequences
+    repeated string stop_sequences = 2;
+    /// Ignore end of sequence token
+    /// used for benchmarking
+    bool ignore_eos_token = 3;
+}
+
+message Request {
+    /// Request ID
+    uint64 id = 1;
+    /// The generation context
+    string inputs = 2;
+    /// Context truncation
+    uint32 truncate = 3;
+    /// Next Token Chooser Parameters
+    NextTokenChooserParameters parameters = 4;
+    /// Stopping Criteria Parameters
+    StoppingCriteriaParameters stopping_parameters = 5;
+    /// Return prefill logprobs
+    bool prefill_logprobs = 6;
+    /// Return most likely n tokens
+    uint32 top_n_tokens = 7;
+}
+
+message Batch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests
+    repeated Request requests = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
+message CachedBatch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests ids
+    repeated uint64 request_ids = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
+enum FinishReason {
+    FINISH_REASON_LENGTH = 0;
+    FINISH_REASON_EOS_TOKEN = 1;
+    FINISH_REASON_STOP_SEQUENCE = 2;
+}
+
+message GeneratedText {
+    /// Output
+    string text = 1;
+    /// Number of generated tokens
+    uint32 generated_tokens = 2;
+    /// Finish reason
+    FinishReason finish_reason = 3;
+    /// Seed
+    optional uint64 seed = 4;
+}
+
+message Tokens {
+    /// Token IDs
+    repeated uint32 ids = 1;
+    /// Logprobs
+    repeated float logprobs = 2;
+    /// tokens
+    repeated string texts = 3;
+    /// special
+    repeated bool is_special = 4;
+}
+
+message Generation {
+    /// Request ID
+    uint64 request_id = 1;
+    /// Prefill tokens (optional)
+    Tokens prefill_tokens = 2;
+    Tokens tokens = 3;
+    /// Complete generated text
+    optional GeneratedText generated_text = 4;
+    /// Top tokens
+    repeated Tokens top_tokens = 5;
+}
+
+message FilterBatchRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to keep
+    repeated uint64 request_ids = 2;
+}
+
+message FilterBatchResponse {
+    /// Filtered Batch (cached)
+    CachedBatch batch = 1;
+}
+
+
+message PrefillRequest {
+    /// Batch
+    Batch batch = 1;
+}
+
+message PrefillResponse {
+    /// Generation
+    repeated Generation generations = 1;
+    /// Next batch (cached)
+    optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+}
+
+message DecodeRequest {
+    /// Cached batches
+    repeated CachedBatch batches = 1;
+}
+
+message DecodeResponse {
+    /// Decodes
+    repeated Generation generations = 1;
+    /// Next batch (cached)
+    optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
+}
+
+message WarmupRequest {
+    /// Batch to warmup on
+    Batch batch = 1;
+    uint32 max_input_length = 2;
+    uint32 max_prefill_tokens = 3;
+    uint32 max_total_tokens = 4;
+}
+
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
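As a usage sketch (not part of the patch itself), the stubs that tonic_build generates from this proto -- see the build.rs changes below -- could be called roughly as follows. The `pb` module path (the module produced by `include_file("mod.rs")`) and the plain TCP endpoint are illustrative assumptions:

// Hypothetical usage of the generated v3 stubs; module path and endpoint are
// assumptions made only for illustration.
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::{HealthRequest, InfoRequest};

async fn probe(endpoint: &'static str) -> Result<(), Box<dyn std::error::Error>> {
    // tonic generates `connect` on the client for simple endpoint strings.
    let mut client = TextGenerationServiceClient::connect(endpoint).await?;

    // Cheap liveness probe.
    client.health(HealthRequest {}).await?;

    // Shard capabilities advertised by the server.
    let info = client.info(InfoRequest {}).await?.into_inner();
    println!("requires_padding: {}", info.requires_padding);
    Ok(())
}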
diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml
index abbde82d..db423c4b 100644
--- a/router/client/Cargo.toml
+++ b/router/client/Cargo.toml
@@ -6,6 +6,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "^0.1"
 base64 = { workspace = true }
 futures = "^0.3"
 grpc-metadata = { path = "../grpc-metadata" }
diff --git a/router/client/build.rs b/router/client/build.rs
index 497be545..bcfab74f 100644
--- a/router/client/build.rs
+++ b/router/client/build.rs
@@ -1,19 +1,31 @@
 use std::fs;
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/generate.proto");
-    fs::create_dir("src/pb").unwrap_or(());
+    println!("cargo:rerun-if-changed=../../proto/**");
+    fs::create_dir_all("src/v2/pb").unwrap_or(());
 
     let mut config = prost_build::Config::new();
     config.protoc_arg("--experimental_allow_proto3_optional");
 
     tonic_build::configure()
         .build_client(true)
         .build_server(false)
-        .out_dir("src/pb")
+        .out_dir("src/v2/pb")
         .include_file("mod.rs")
         .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
         .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
 
+    fs::create_dir_all("src/v3/pb").unwrap_or(());
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/v3/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
+        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
     Ok(())
 }
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 9e9ef13b..c0c1274a 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -1,25 +1,32 @@
 //! Text Generation gRPC client library
 
-mod client;
-#[allow(clippy::derive_partial_eq_without_eq)]
-mod pb;
-mod sharded_client;
+pub mod v2;
+pub mod v3;
 
-use base64::{engine::general_purpose::STANDARD, Engine};
-pub use client::Client;
-pub use pb::generate::v2::input_chunk::Chunk;
-pub use pb::generate::v2::HealthResponse;
-pub use pb::generate::v2::Image;
-pub use pb::generate::v2::InfoResponse as ShardInfo;
-pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
-};
-pub use sharded_client::ShardedClient;
+use async_trait::async_trait;
 use thiserror::Error;
 use tonic::transport;
 use tonic::Status;
 
+#[async_trait]
+pub trait Health {
+    /// Check if a generate server is healthy by asking it to allocate a tensor on device
+    async fn device_health(&self) -> Result<()>;
+
+    /// Check if a generate server is healthy by doing a forward pass.
+    /// EXPENSIVE
+    async fn model_health(&self) -> Result<()>;
+}
+
+#[derive(Debug)]
+pub struct ShardInfo {
+    pub requires_padding: bool,
+    pub dtype: String,
+    pub device_type: String,
+    pub window_size: Option<u32>,
+    pub speculate: u32,
+}
+
 #[derive(Error, Debug, Clone)]
 pub enum ClientError {
     #[error("Could not connect to Text Generation server: {0}")]
diff --git a/router/client/src/pb/.gitignore b/router/client/src/pb/.gitignore
deleted file mode 100644
index 6f5f3d11..00000000
--- a/router/client/src/pb/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-*.rs
diff --git a/router/client/src/v2/mod.rs b/router/client/src/v2/mod.rs
new file mode 100644
index 00000000..6b14b9f3
--- /dev/null
+++ b/router/client/src/v2/mod.rs
@@ -0,0 +1,13 @@
+#[allow(clippy::derive_partial_eq_without_eq)]
+mod pb;
+
+mod client;
+mod sharded_client;
+
+pub use client::Client;
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::{
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+};
+pub use sharded_client::ShardedClient;
diff --git a/router/client/src/v2/pb/.gitignore b/router/client/src/v2/pb/.gitignore
new file mode 100644
index 00000000..72e8ffc0
--- /dev/null
+++ b/router/client/src/v2/pb/.gitignore
@@ -0,0 +1 @@
+*
diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs
new file mode 100644
index 00000000..7d551c13
--- /dev/null
+++ b/router/client/src/v3/mod.rs
@@ -0,0 +1,13 @@
+#[allow(clippy::derive_partial_eq_without_eq)]
+mod pb;
+
+mod client;
+mod sharded_client;
+
+pub use client::Client;
+pub use pb::generate::v3::HealthResponse;
+pub use pb::generate::v3::{
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+};
+pub use sharded_client::ShardedClient;
diff --git a/router/client/src/v3/pb/.gitignore b/router/client/src/v3/pb/.gitignore
new file mode 100644
index 00000000..72e8ffc0
--- /dev/null
+++ b/router/client/src/v3/pb/.gitignore
@@ -0,0 +1 @@
+*
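To illustrate the abstraction introduced above (a sketch, not part of the patch): any protocol-specific client can plug into the router's health checking by implementing the new `Health` trait. The `MockClient` type is hypothetical, and writing the return type as `Result<(), ClientError>` assumes the crate's `Result<T>` alias expands to `std::result::Result<T, ClientError>`, which is not visible in this diff:

use async_trait::async_trait;
use text_generation_client::{ClientError, Health};

// Hypothetical client used only to illustrate the trait surface.
struct MockClient;

#[async_trait]
impl Health for MockClient {
    async fn device_health(&self) -> Result<(), ClientError> {
        // Cheap probe, e.g. ask every shard to allocate a small tensor.
        Ok(())
    }

    async fn model_health(&self) -> Result<(), ClientError> {
        // Expensive probe, e.g. run a one-token prefill through the model.
        Ok(())
    }
}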
diff --git a/router/src/health.rs b/router/src/health.rs
index 121255b9..4320c1a4 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,22 +1,18 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
-use text_generation_client::{
-    Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
-};
-use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
+use text_generation_client::Health;
 
-// Note: Request ids and batch ids cannot collide.
-const LIVENESS_ID: u64 = u64::MAX;
-const BATCH_ID: u64 = u64::MAX;
-
-#[derive(Clone, Debug)]
-pub(crate) struct Health {
-    client: ShardedClient,
+#[derive(Clone)]
+pub(crate) struct HealthCheck {
+    client: Arc<dyn Health + Send + Sync>,
     generation_health: Arc<AtomicBool>,
 }
 
-impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+impl HealthCheck {
+    pub(crate) fn new(
+        client: Arc<dyn Health + Send + Sync>,
+        generation_health: Arc<AtomicBool>,
+    ) -> Self {
         Self {
             client,
             generation_health,
@@ -24,52 +20,15 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if self.generation_health.load(Ordering::SeqCst) {
-            // Generation is healthy, we only check that the shards are answering gRPC calls
-            self.client.health().await.is_ok()
+        let value = if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards can allocate on device
+            self.client.device_health().await
         } else {
-            // Generation is unhealthy or have not sent any generation request yet
-
-            // Dummy batch of 1 token and 1 generated token
-            let liveness_request = Request {
-                id: LIVENESS_ID,
-                input_chunks: Some(Input {
-                    chunks: vec![Chunk::Text("liveness".into()).into()],
-                }),
-                inputs: "liveness".to_string(),
-                truncate: 10,
-                prefill_logprobs: false,
-                parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
-                    seed: 0,
-                    repetition_penalty: 1.0,
-                    frequency_penalty: 0.0,
-                    watermark: false,
-                    grammar: String::new(),
-                    grammar_type: ProtoGrammarType::None as i32,
-                }),
-                stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 1,
-                    stop_sequences: vec![],
-                    ignore_eos_token: false,
-                }),
-                top_n_tokens: 0,
-            };
-            let batch = Batch {
-                id: BATCH_ID,
-                requests: vec![liveness_request],
-                size: 1,
-                max_tokens: 2,
-            };
-            // Skips the queue
-            let value = self.client.prefill(batch).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            self.client.model_health().await
         }
+        .is_ok();
+        // Update generation health
+        self.generation_health.store(value, Ordering::SeqCst);
+        value
     }
 }
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 0410de7d..6279cc5d 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,9 +1,9 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
-    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
-    TextMessage, Token,
+    ChatTemplateInputs, ChatTemplateVersions, Entry, FinishReason, GenerateRequest,
+    GenerateStreamResponse, HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
+    PrefillToken, Queue, Text, TextMessage, Token,
 };
 use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
 use futures::future::try_join_all;
@@ -15,9 +15,8 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::{
-    Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
-};
+use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::{v2, ClientError};
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
@@ -232,16 +231,8 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(tokens) => {
-                    // Create Token objects
-                    // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                InferStreamResponse::Prefill(prefill_tokens) => {
+                    result_prefill = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Intermediate { token, top_tokens } => {
@@ -792,6 +783,16 @@ fn send_responses(
     let mut stopped = false;
 
     if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Create Token objects
+        // We do that here instead of in the Python code as Rust for loops are faster
+        let prefill_tokens = prefill_tokens
+            .ids
+            .into_iter()
+            .zip(prefill_tokens.logprobs.into_iter())
+            .zip(prefill_tokens.texts.into_iter())
+            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+            .collect();
+
         // Send message
         entry
             .response_tx
@@ -842,7 +843,7 @@
             entry.response_tx.send(Ok(InferStreamResponse::End {
                 token,
                 top_tokens,
-                generated_text: generated_text.clone(),
+                generated_text: GeneratedText::from(generated_text.clone()),
                 queued: entry.queue_time,
                 start: entry.batch_time.unwrap(),
             }))?;
@@ -877,10 +878,36 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     });
 }
 
+#[derive(Debug)]
+pub(crate) struct GeneratedText {
+    pub(crate) text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) finish_reason: FinishReason,
+    pub(crate) seed: Option<u64>,
+}
+
+impl From<v2::GeneratedText> for GeneratedText {
+    fn from(value: v2::GeneratedText) -> Self {
+        let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap();
+        let finish_reason = match v2_finish_reason {
+            v2::FinishReason::Length => FinishReason::Length,
+            v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            v2::FinishReason::StopSequence => FinishReason::StopSequence,
+        };
+
+        Self {
+            text: value.text,
+            generated_tokens: value.generated_tokens,
+            finish_reason,
+            seed: value.seed,
+        }
+    }
+}
+
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(Tokens),
+    Prefill(Vec<PrefillToken>),
     // Intermediate messages
     Intermediate {
         token: Token,
@@ -1355,11 +1382,11 @@ mod tests {
         chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
             input: ChatTemplateInputs {
                 messages: vec![
-                    TextMessage{
+                    TextMessage {
                         role: "system".to_string(),
                         content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
                     },
-                    TextMessage{
+                    TextMessage {
                         role: "user".to_string(),
                         content: "How many helicopters can a human eat in one sitting?".to_string(),
                     },
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 9b3283df..d687794d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1087,7 +1087,7 @@ pub struct SimpleToken {
     stop: usize,
 }
 
-#[derive(Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 #[schema(example = "Length")]
 pub(crate) enum FinishReason {
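The conversion above is what keeps the router-side `GeneratedText` protocol-agnostic. As a sketch of where this decoupling points (not part of the patch), a v3 backend could add the mirror-image conversion, since `v3::GeneratedText` and `v3::FinishReason` are exported from router/client/src/v3/mod.rs with the same shape as their v2 counterparts. The impl is assumed to live next to the v2 one in router/src/infer.rs, where `GeneratedText` and `FinishReason` are already in scope:

// Illustrative sketch only: a v3 counterpart of the From<v2::GeneratedText> impl.
use text_generation_client::v3;

impl From<v3::GeneratedText> for GeneratedText {
    fn from(value: v3::GeneratedText) -> Self {
        let v3_finish_reason = v3::FinishReason::try_from(value.finish_reason).unwrap();
        let finish_reason = match v3_finish_reason {
            v3::FinishReason::Length => FinishReason::Length,
            v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            v3::FinishReason::StopSequence => FinishReason::StopSequence,
        };

        Self {
            text: value.text,
            generated_tokens: value.generated_tokens,
            finish_reason,
            seed: value.seed,
        }
    }
}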
diff --git a/router/src/main.rs b/router/src/main.rs
index b526367c..08faba40 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -12,7 +12,7 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
-use text_generation_client::{ClientError, ShardedClient};
+use text_generation_client::{v2::ShardedClient, ClientError};
 use text_generation_router::config::Config;
 use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
 use thiserror::Error;
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 40692ffc..705871b8 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -1,12 +1,17 @@
 use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
-use crate::validation::ValidGenerateRequest;
+use crate::validation::{
+    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
+};
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::ChunksToString;
 use text_generation_client::Input;
 use text_generation_client::{Batch, Request};
+use text_generation_client::v2::{
+    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
@@ -285,8 +290,12 @@ impl State {
                 }),
                 inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
-                parameters: Some(entry.request.parameters.clone()),
-                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                parameters: Some(NextTokenChooserParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                stopping_parameters: Some(StoppingCriteriaParameters::from(
+                    entry.request.stopping_parameters.clone(),
+                )),
                 top_n_tokens: entry.request.top_n_tokens,
             });
             // Set batch_time
@@ -355,6 +364,43 @@ enum QueueCommand {
     },
 }
 
+impl From<ValidParameters> for NextTokenChooserParameters {
+    fn from(value: ValidParameters) -> Self {
+        let (grammar, grammar_type) = match value.grammar {
+            None => (String::new(), GrammarType::None),
+
+            Some(grammar) => match grammar {
+                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
+                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
+            },
+        };
+
+        Self {
+            temperature: value.temperature,
+            top_k: value.top_k,
+            top_p: value.top_p,
+            typical_p: value.typical_p,
+            do_sample: value.do_sample,
+            seed: value.seed,
+            repetition_penalty: value.repetition_penalty,
+            frequency_penalty: value.frequency_penalty,
+            watermark: value.watermark,
+            grammar,
+            grammar_type: grammar_type.into(),
+        }
+    }
+}
+
+impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
+    fn from(value: ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: value.max_new_tokens,
+            stop_sequences: value.stop_sequences,
+            ignore_eos_token: value.ignore_eos_token,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/router/src/server.rs b/router/src/server.rs
index eb7ba2a0..f44c57ef 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,6 +1,6 @@
-use crate::config::Config;
 /// HTTP Server logic
-use crate::health::Health;
+use crate::config::Config;
+use crate::health::HealthCheck;
 use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
 use crate::validation::ValidationError;
 use crate::{
@@ -34,7 +34,7 @@ use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use text_generation_client::{ShardInfo, ShardedClient};
+use text_generation_client::{v2::ShardedClient, ShardInfo};
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -115,7 +115,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
 )]
 #[instrument(skip(health))]
 /// Health check method
-async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+async fn health(
+    mut health: Extension<HealthCheck>,
+) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     match health.check().await {
         true => Ok(()),
         false => Err((
@@ -1482,7 +1484,7 @@ pub async fn run(
         grammar_support,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+    let health_ext = HealthCheck::new(Arc::new(client.clone()), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -1719,17 +1721,6 @@ async fn shutdown_signal() {
     opentelemetry::global::shutdown_tracer_provider();
 }
 
-impl From<i32> for FinishReason {
-    fn from(finish_reason: i32) -> Self {
-        let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
-        match finish_reason {
-            text_generation_client::FinishReason::Length => FinishReason::Length,
-            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
-            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
-        }
-    }
-}
-
 /// Convert to Axum supported formats
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 863bb99b..c321c33b 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -6,10 +6,6 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
-use text_generation_client::{
-    Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
-    StoppingCriteriaParameters,
-};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 // use tokenizers::TruncationDirection;
@@ -173,10 +169,6 @@ impl Validation {
             // Validate MaxNewTokens
             if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                 input_length = input_length.saturating_sub(max_new_tokens as usize);
-                // return Err(ValidationError::MaxNewTokens(
-                //     self.max_total_tokens - self.max_input_length,
-                //     max_new_tokens,
-                // ));
             }
 
             Ok((
@@ -327,13 +319,13 @@ impl Validation {
         // compiler and use that to build the FSM here.
 
         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let grammar = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
                     return Err(ValidationError::Grammar);
                 }
-                match grammar {
+                let valid_grammar = match grammar {
                     GrammarType::Json(json) => {
                         let json = match json {
                             // if value is a string, we need to parse it again to make sure its
@@ -350,20 +342,20 @@ impl Validation {
                             .compile(&json)
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
-                        (
-                            // Serialize json to string
+                        // Serialize json to string
+                        ValidGrammar::Json(
                             serde_json::to_string(&json)
                                 .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                            ProtoGrammarType::Json.into(),
                         )
                     }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
-                }
+                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
+                };
+                Some(valid_grammar)
             }
-            None => (String::new(), ProtoGrammarType::None.into()),
+            None => None,
         };
 
-        let parameters = NextTokenChooserParameters {
+        let parameters = ValidParameters {
             temperature,
             repetition_penalty,
             frequency_penalty,
@@ -374,9 +366,8 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            grammar_type,
         };
-        let stopping_parameters = StoppingCriteriaParameters {
+        let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
             stop_sequences,
             ignore_eos_token: false,
@@ -458,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
         _ => None,
     }
 }
+
 fn format_to_mimetype(format: ImageFormat) -> String {
     match format {
         ImageFormat::Png => "image/png",
@@ -636,14 +628,55 @@ type TokenizerRequest = (
     Span,
 );
 
+#[derive(Debug, Clone)]
+pub(crate) enum ValidGrammar {
+    Json(String),
+    Regex(String),
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidParameters {
+    /// / exponential scaling output probability distribution
+    pub temperature: f32,
+    /// / restricting to the k highest probability elements
+    pub top_k: u32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub top_p: f32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub typical_p: f32,
+    /// / apply sampling on the logits
+    pub do_sample: bool,
+    /// / random seed for sampling
+    pub seed: u64,
+    /// / repetition penalty
+    pub repetition_penalty: f32,
+    /// / frequency penalty
+    pub frequency_penalty: f32,
+    /// / token watermarking using "A Watermark for Large Language Models"
+    pub watermark: bool,
+    /// / grammar (applied if not empty)
+    pub grammar: Option<ValidGrammar>,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidStoppingParameters {
+    /// / Maximum number of generated tokens
+    pub max_new_tokens: u32,
+    /// / Optional stopping sequences
+    pub stop_sequences: Vec<String>,
+    /// / Ignore end of sequence token
+    /// / used for benchmarking
+    pub ignore_eos_token: bool,
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: Vec<InputChunk>,
     pub input_length: u32,
     pub truncate: u32,
     pub decoder_input_details: bool,
-    pub parameters: NextTokenChooserParameters,
-    pub stopping_parameters: StoppingCriteriaParameters,
+    pub parameters: ValidParameters,
+    pub stopping_parameters: ValidStoppingParameters,
     pub top_n_tokens: u32,
 }
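Taken together, validation now produces the protocol-agnostic `ValidParameters` / `ValidStoppingParameters`, and only queue.rs maps them onto wire types. A small illustrative sketch of that conversion path (not part of the patch, and assumed to live inside the router crate, where these crate-private types are visible):

use text_generation_client::v2::{NextTokenChooserParameters, StoppingCriteriaParameters};

fn example_wire_types() -> (NextTokenChooserParameters, StoppingCriteriaParameters) {
    let parameters = ValidParameters {
        temperature: 1.0,
        top_k: 0,
        top_p: 1.0,
        typical_p: 1.0,
        do_sample: false,
        seed: 0,
        repetition_penalty: 1.0,
        frequency_penalty: 0.0,
        watermark: false,
        // A JSON grammar is carried as ValidGrammar::Json and only mapped to
        // (grammar string, GrammarType::Json) at the protocol boundary.
        grammar: Some(ValidGrammar::Json(r#"{"type": "object"}"#.to_string())),
    };
    let stopping = ValidStoppingParameters {
        max_new_tokens: 1,
        stop_sequences: vec![],
        ignore_eos_token: false,
    };
    // `.into()` uses the From impls added in router/src/queue.rs.
    (parameters.into(), stopping.into())
}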