small refactor to make router a bit more agnostic

OlivierDehaene 2024-06-03 13:30:31 +02:00
parent df71aafdcc
commit 4dbb342fe3
16 changed files with 478 additions and 139 deletions

proto/v3/generate.proto (new file)

@ -0,0 +1,236 @@
syntax = "proto3";
package generate.v3;
service TextGenerationService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
message HealthResponse {}
/// Empty request
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
}
/// Empty request
message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse {
/// Other shards urls
repeated string urls = 1;
}
message ClearCacheRequest {
/// Optional batch id
optional uint64 id = 1;
}
/// Empty response
message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Context truncation
uint32 truncate = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Logprobs
repeated float logprobs = 2;
/// tokens
repeated string texts = 3;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 4;
/// Top tokens
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
}
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}

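For orientation, here is a minimal sketch of how the v3 stubs above could be called once tonic-build has compiled them. The `pb::generate::v3` module path and the endpoint address are assumptions for illustration (mirroring what the build script later in this commit produces); this is not code from the commit.

// Hedged sketch: assumes the proto above has been compiled by tonic-build into a
// local `pb` module; the endpoint address is illustrative only.
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::InfoRequest;
use tonic::transport::Channel;

async fn print_shard_info() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to a single shard and ask for its static model info.
    let channel = Channel::from_static("http://127.0.0.1:3000").connect().await?;
    let mut client = TextGenerationServiceClient::new(channel);
    let info = client.info(InfoRequest {}).await?.into_inner();
    println!(
        "requires_padding={} dtype={} device={} speculate={}",
        info.requires_padding, info.dtype, info.device_type, info.speculate
    );
    Ok(())
}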

@ -6,6 +6,7 @@ authors.workspace = true
homepage.workspace = true

[dependencies]
+async-trait = "^0.1"
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }


@ -1,19 +1,31 @@
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/generate.proto");
-    fs::create_dir("src/pb").unwrap_or(());
+    println!("cargo:rerun-if-changed=../../proto/**");
+
+    fs::create_dir_all("src/v2/pb").unwrap_or(());
    let mut config = prost_build::Config::new();
    config.protoc_arg("--experimental_allow_proto3_optional");

    tonic_build::configure()
        .build_client(true)
        .build_server(false)
-        .out_dir("src/pb")
+        .out_dir("src/v2/pb")
        .include_file("mod.rs")
        .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));

+    fs::create_dir_all("src/v3/pb").unwrap_or(());
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/v3/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
+        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
    Ok(())
}


@ -1,25 +1,32 @@
//! Text Generation gRPC client library
-mod client;
-#[allow(clippy::derive_partial_eq_without_eq)]
-mod pb;
-mod sharded_client;
+pub mod v2;
+pub mod v3;

-use base64::{engine::general_purpose::STANDARD, Engine};
-pub use client::Client;
-pub use pb::generate::v2::input_chunk::Chunk;
-pub use pb::generate::v2::HealthResponse;
-pub use pb::generate::v2::Image;
-pub use pb::generate::v2::InfoResponse as ShardInfo;
-pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
-};
-pub use sharded_client::ShardedClient;
+use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;

+#[async_trait]
+pub trait Health {
+    /// Check if a generate server is healthy by asking it to allocate a tensor on device
+    async fn device_health(&self) -> Result<()>;
+
+    /// Check if a generate server is healthy by doing a forward pass.
+    /// EXPENSIVE
+    async fn model_health(&self) -> Result<()>;
+}
+
+#[derive(Debug)]
+pub struct ShardInfo {
+    pub requires_padding: bool,
+    pub dtype: String,
+    pub device_type: String,
+    pub window_size: Option<u32>,
+    pub speculate: u32,
+}
+
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0}")]

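The new `Health` trait above is the seam that makes the router protocol-agnostic: the health check only needs something that can answer the two probes. A minimal sketch of an implementor follows (hypothetical, not part of this commit; it assumes the crate's `Result` alias over `ClientError` and that `Health` is in scope).

use async_trait::async_trait;

// Hypothetical stand-in for a shard client; the real implementors are the v2/v3 clients.
struct MockShard;

#[async_trait]
impl Health for MockShard {
    async fn device_health(&self) -> Result<()> {
        // Cheap probe: a real client would ask the shard to allocate a small tensor on device.
        Ok(())
    }

    async fn model_health(&self) -> Result<()> {
        // Expensive probe: a real client would run a tiny prefill through the model.
        Ok(())
    }
}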

@ -1 +0,0 @@
*.rs


@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

router/client/src/v2/pb/.gitignore (new file)

@ -0,0 +1 @@
*


@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::HealthResponse;
pub use pb::generate::v3::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

router/client/src/v3/pb/.gitignore (new file)

@ -0,0 +1 @@
*


@ -1,22 +1,18 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
-use text_generation_client::{
-    Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
-};
-use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
-// Note: Request ids and batch ids cannot collide.
-const LIVENESS_ID: u64 = u64::MAX;
-const BATCH_ID: u64 = u64::MAX;
-#[derive(Clone, Debug)]
-pub(crate) struct Health {
-    client: ShardedClient,
+use text_generation_client::Health;
+
+#[derive(Clone)]
+pub(crate) struct HealthCheck {
+    client: Arc<dyn Health + Send + Sync>,
    generation_health: Arc<AtomicBool>,
}

-impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+impl HealthCheck {
+    pub(crate) fn new(
+        client: Arc<dyn Health + Send + Sync>,
+        generation_health: Arc<AtomicBool>,
+    ) -> Self {
        Self {
            client,
            generation_health,
@ -24,52 +20,15 @@ impl Health {
    }

    pub(crate) async fn check(&mut self) -> bool {
-        if self.generation_health.load(Ordering::SeqCst) {
-            // Generation is healthy, we only check that the shards are answering gRPC calls
-            self.client.health().await.is_ok()
+        let value = if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards can allocate on device
+            self.client.device_health().await
        } else {
-            // Generation is unhealthy or have not sent any generation request yet
-            // Dummy batch of 1 token and 1 generated token
-            let liveness_request = Request {
-                id: LIVENESS_ID,
-                input_chunks: Some(Input {
-                    chunks: vec![Chunk::Text("liveness".into()).into()],
-                }),
-                inputs: "liveness".to_string(),
-                truncate: 10,
-                prefill_logprobs: false,
-                parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
-                    seed: 0,
-                    repetition_penalty: 1.0,
-                    frequency_penalty: 0.0,
-                    watermark: false,
-                    grammar: String::new(),
-                    grammar_type: ProtoGrammarType::None as i32,
-                }),
-                stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 1,
-                    stop_sequences: vec![],
-                    ignore_eos_token: false,
-                }),
-                top_n_tokens: 0,
-            };
-            let batch = Batch {
-                id: BATCH_ID,
-                requests: vec![liveness_request],
-                size: 1,
-                max_tokens: 2,
-            };
-            // Skips the queue
-            let value = self.client.prefill(batch).await.is_ok();
+            self.client.model_health().await
+        }
+        .is_ok();
        // Update generation health
        self.generation_health.store(value, Ordering::SeqCst);
        value
-        }
    }
}

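Because `HealthCheck` now holds an `Arc<dyn Health + Send + Sync>` instead of a concrete `ShardedClient`, any protocol version that implements the trait can back the health route. A rough usage sketch follows (the helper name and wiring are illustrative, not part of the commit).

use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::Health;

// Illustrative helper: build a HealthCheck from any trait object and run one probe.
async fn probe(client: Arc<dyn Health + Send + Sync>) -> bool {
    // Flipped to true once real generations succeed, so later checks can take the
    // cheap device-only path instead of a full forward pass.
    let generation_health = Arc::new(AtomicBool::new(false));
    let mut health = HealthCheck::new(client, generation_health);
    health.check().await
}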

@ -1,9 +1,9 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
-    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
-    TextMessage, Token,
+    ChatTemplateInputs, ChatTemplateVersions, Entry, FinishReason, GenerateRequest,
+    GenerateStreamResponse, HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
+    PrefillToken, Queue, Text, TextMessage, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
@ -15,9 +15,8 @@ use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
-use text_generation_client::{
-    Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
-};
+use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::{v2, ClientError};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
@ -232,16 +231,8 @@ impl Infer {
        while let Some(response) = stream.next().await {
            match response? {
                // Add prefill tokens
-                InferStreamResponse::Prefill(tokens) => {
-                    // Create Token objects
-                    // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                InferStreamResponse::Prefill(prefill_tokens) => {
+                    result_prefill = prefill_tokens;
                }
                // Push last token
                InferStreamResponse::Intermediate { token, top_tokens } => {
@ -792,6 +783,16 @@ fn send_responses(
    let mut stopped = false;

    if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Create Token objects
+        // We do that here instead of in the Python code as Rust for loops are faster
+        let prefill_tokens = prefill_tokens
+            .ids
+            .into_iter()
+            .zip(prefill_tokens.logprobs.into_iter())
+            .zip(prefill_tokens.texts.into_iter())
+            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+            .collect();
+
        // Send message
        entry
            .response_tx
@ -842,7 +843,7 @@ fn send_responses(
        entry.response_tx.send(Ok(InferStreamResponse::End {
            token,
            top_tokens,
-            generated_text: generated_text.clone(),
+            generated_text: GeneratedText::from(generated_text.clone()),
            queued: entry.queue_time,
            start: entry.batch_time.unwrap(),
        }))?;
@ -877,10 +878,36 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    });
}

+#[derive(Debug)]
+pub(crate) struct GeneratedText {
+    pub(crate) text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) finish_reason: FinishReason,
+    pub(crate) seed: Option<u64>,
+}
+
+impl From<v2::GeneratedText> for GeneratedText {
+    fn from(value: v2::GeneratedText) -> Self {
+        let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap();
+        let finish_reason = match v2_finish_reason {
+            v2::FinishReason::Length => FinishReason::Length,
+            v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            v2::FinishReason::StopSequence => FinishReason::StopSequence,
+        };
+
+        Self {
+            text: value.text,
+            generated_tokens: value.generated_tokens,
+            finish_reason,
+            seed: value.seed,
+        }
+    }
+}
+
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
    // Optional first message
-    Prefill(Tokens),
+    Prefill(Vec<PrefillToken>),
    // Intermediate messages
    Intermediate {
        token: Token,

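With the router-side `GeneratedText` in place, the only protocol-specific piece is the `From` impl above. A hypothetical mirror for the v3 protocol added in this commit would look the same (a sketch, not part of the commit; it assumes the re-exports shown in the v3 module earlier).

use text_generation_client::v3;

// Hypothetical v3 counterpart of the v2 conversion above.
impl From<v3::GeneratedText> for GeneratedText {
    fn from(value: v3::GeneratedText) -> Self {
        let v3_finish_reason = v3::FinishReason::try_from(value.finish_reason).unwrap();
        let finish_reason = match v3_finish_reason {
            v3::FinishReason::Length => FinishReason::Length,
            v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            v3::FinishReason::StopSequence => FinishReason::StopSequence,
        };

        Self {
            text: value.text,
            generated_tokens: value.generated_tokens,
            finish_reason,
            seed: value.seed,
        }
    }
}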

@ -1087,7 +1087,7 @@ pub struct SimpleToken {
    stop: usize,
}

-#[derive(Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {


@ -12,7 +12,7 @@ use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
-use text_generation_client::{ClientError, ShardedClient};
+use text_generation_client::{v2::ShardedClient, ClientError};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error;


@ -1,12 +1,17 @@
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
-use crate::validation::ValidGenerateRequest;
+use crate::validation::{
+    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
+};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::ChunksToString;
use text_generation_client::Input;
-use text_generation_client::{Batch, Request};
+use text_generation_client::v2::{
+    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -285,8 +290,12 @@ impl State {
            }),
            inputs: entry.request.inputs.chunks_to_string(),
            truncate: entry.request.truncate,
-            parameters: Some(entry.request.parameters.clone()),
-            stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+            parameters: Some(NextTokenChooserParameters::from(
+                entry.request.parameters.clone(),
+            )),
+            stopping_parameters: Some(StoppingCriteriaParameters::from(
+                entry.request.stopping_parameters.clone(),
+            )),
            top_n_tokens: entry.request.top_n_tokens,
        });
        // Set batch_time
@ -355,6 +364,43 @@ enum QueueCommand {
    },
}

+impl From<ValidParameters> for NextTokenChooserParameters {
+    fn from(value: ValidParameters) -> Self {
+        let (grammar, grammar_type) = match value.grammar {
+            None => (String::new(), GrammarType::None),
+            Some(grammar) => match grammar {
+                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
+                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
+            },
+        };
+
+        Self {
+            temperature: value.temperature,
+            top_k: value.top_k,
+            top_p: value.top_p,
+            typical_p: value.typical_p,
+            do_sample: value.do_sample,
+            seed: value.seed,
+            repetition_penalty: value.repetition_penalty,
+            frequency_penalty: value.frequency_penalty,
+            watermark: value.watermark,
+            grammar,
+            grammar_type: grammar_type.into(),
+        }
+    }
+}
+
+impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
+    fn from(value: ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: value.max_new_tokens,
+            stop_sequences: value.stop_sequences,
+            ignore_eos_token: value.ignore_eos_token,
+        }
+    }
+}
+
#[cfg(test)]
mod tests {
    use super::*;

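These two `From` impls are the only place where the router's validated parameters touch the v2 wire types. A small sketch of the boundary conversion with illustrative values (not part of the commit):

// Illustrative only: build router-internal parameters, then convert them to the v2
// proto messages right before they are placed into a gRPC Request.
fn conversion_example() {
    let parameters = ValidParameters {
        temperature: 1.0,
        top_k: 0,
        top_p: 1.0,
        typical_p: 1.0,
        do_sample: false,
        seed: 42,
        repetition_penalty: 1.0,
        frequency_penalty: 0.0,
        watermark: false,
        grammar: Some(ValidGrammar::Regex("[a-z]+".to_string())),
    };
    let stopping = ValidStoppingParameters {
        max_new_tokens: 16,
        stop_sequences: vec!["\n\n".to_string()],
        ignore_eos_token: false,
    };

    let proto_parameters = NextTokenChooserParameters::from(parameters);
    let proto_stopping = StoppingCriteriaParameters::from(stopping);
    assert_eq!(proto_parameters.grammar_type, GrammarType::Regex as i32);
    assert_eq!(proto_stopping.max_new_tokens, 16);
}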

@ -1,6 +1,6 @@
-use crate::config::Config;
/// HTTP Server logic
-use crate::health::Health;
+use crate::config::Config;
+use crate::health::HealthCheck;
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
@ -34,7 +34,7 @@ use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
-use text_generation_client::{ShardInfo, ShardedClient};
+use text_generation_client::{v2::ShardedClient, ShardInfo};
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
@ -115,7 +115,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)]
#[instrument(skip(health))]
/// Health check method
-async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+async fn health(
+    mut health: Extension<HealthCheck>,
+) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    match health.check().await {
        true => Ok(()),
        false => Err((
@ -1482,7 +1484,7 @@ pub async fn run(
        grammar_support,
    );
    let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+    let health_ext = HealthCheck::new(Arc::new(client.clone()), generation_health.clone());
    let infer = Infer::new(
        client,
        validation,
@ -1719,17 +1721,6 @@ async fn shutdown_signal() {
    opentelemetry::global::shutdown_tracer_provider();
}

-impl From<i32> for FinishReason {
-    fn from(finish_reason: i32) -> Self {
-        let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
-        match finish_reason {
-            text_generation_client::FinishReason::Length => FinishReason::Length,
-            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
-            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
-        }
-    }
-}
-
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {


@ -6,10 +6,6 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
-use text_generation_client::{
-    Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
-    StoppingCriteriaParameters,
-};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
// use tokenizers::TruncationDirection;
@ -173,10 +169,6 @@ impl Validation {
        // Validate MaxNewTokens
        if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
            input_length = input_length.saturating_sub(max_new_tokens as usize);
-            // return Err(ValidationError::MaxNewTokens(
-            //     self.max_total_tokens - self.max_input_length,
-            //     max_new_tokens,
-            // ));
        }

        Ok((
@ -327,13 +319,13 @@ impl Validation {
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let grammar = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
-                match grammar {
+                let valid_grammar = match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure its
@ -350,20 +342,20 @@ impl Validation {
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

-                        (
                            // Serialize json to string
+                        ValidGrammar::Json(
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                            ProtoGrammarType::Json.into(),
                        )
                    }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
-                }
+                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
+                };
+                Some(valid_grammar)
            }
-            None => (String::new(), ProtoGrammarType::None.into()),
+            None => None,
        };

-        let parameters = NextTokenChooserParameters {
+        let parameters = ValidParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
@ -374,9 +366,8 @@ impl Validation {
            seed,
            watermark,
            grammar,
-            grammar_type,
        };

-        let stopping_parameters = StoppingCriteriaParameters {
+        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
@ -458,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
        _ => None,
    }
}
+
fn format_to_mimetype(format: ImageFormat) -> String {
    match format {
        ImageFormat::Png => "image/png",
@ -636,14 +628,55 @@ type TokenizerRequest = (
    Span,
);

+#[derive(Debug, Clone)]
+pub(crate) enum ValidGrammar {
+    Json(String),
+    Regex(String),
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidParameters {
+    /// / exponential scaling output probability distribution
+    pub temperature: f32,
+    /// / restricting to the k highest probability elements
+    pub top_k: u32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub top_p: f32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub typical_p: f32,
+    /// / apply sampling on the logits
+    pub do_sample: bool,
+    /// / random seed for sampling
+    pub seed: u64,
+    /// / repetition penalty
+    pub repetition_penalty: f32,
+    /// / frequency penalty
+    pub frequency_penalty: f32,
+    /// / token watermarking using "A Watermark for Large Language Models"
+    pub watermark: bool,
+    /// / grammar (applied if not empty)
+    pub grammar: Option<ValidGrammar>,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidStoppingParameters {
+    /// / Maximum number of generated tokens
+    pub max_new_tokens: u32,
+    /// / Optional stopping sequences
+    pub stop_sequences: Vec<String>,
+    /// / Ignore end of sequence token
+    /// / used for benchmarking
+    pub ignore_eos_token: bool,
+}
+
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
    pub inputs: Vec<InputChunk>,
    pub input_length: u32,
    pub truncate: u32,
    pub decoder_input_details: bool,
-    pub parameters: NextTokenChooserParameters,
-    pub stopping_parameters: StoppingCriteriaParameters,
+    pub parameters: ValidParameters,
+    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
}