small refactor to make router a bit more agnostic

This commit is contained in:
OlivierDehaene 2024-06-03 13:30:31 +02:00
parent df71aafdcc
commit 4dbb342fe3
16 changed files with 478 additions and 139 deletions

236
proto/v3/generate.proto Normal file
View File

@ -0,0 +1,236 @@
syntax = "proto3";
package generate.v3;
service TextGenerationService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
message HealthResponse {}
/// Empty request
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
}
/// Empty request
message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse {
/// Other shards urls
repeated string urls = 1;
}
message ClearCacheRequest {
/// Optional batch id
optional uint64 id = 1;
}
/// Empty response
message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Context truncation
uint32 truncate = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Logprobs
repeated float logprobs = 2;
/// tokens
repeated string texts = 3;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 4;
/// Top tokens
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
}
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}

View File

@ -6,6 +6,7 @@ authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "^0.1"
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }

View File

@ -1,19 +1,31 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto");
fs::create_dir("src/pb").unwrap_or(());
println!("cargo:rerun-if-changed=../../proto/**");
fs::create_dir_all("src/v2/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/pb")
.out_dir("src/v2/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
fs::create_dir_all("src/v3/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v3/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

View File

@ -1,25 +1,32 @@
//! Text Generation gRPC client library
mod client;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod sharded_client;
pub mod v2;
pub mod v3;
use base64::{engine::general_purpose::STANDARD, Engine};
pub use client::Client;
pub use pb::generate::v2::input_chunk::Chunk;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::Image;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]

View File

@ -1 +0,0 @@
*.rs

View File

@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

1
router/client/src/v2/pb/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*

View File

@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::HealthResponse;
pub use pb::generate::v3::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

1
router/client/src/v3/pb/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*

View File

@ -1,22 +1,18 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::{
Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
use text_generation_client::Health;
// Note: Request ids and batch ids cannot collide.
const LIVENESS_ID: u64 = u64::MAX;
const BATCH_ID: u64 = u64::MAX;
#[derive(Clone, Debug)]
pub(crate) struct Health {
client: ShardedClient,
#[derive(Clone)]
pub(crate) struct HealthCheck {
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
}
impl Health {
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
impl HealthCheck {
pub(crate) fn new(
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
) -> Self {
Self {
client,
generation_health,
@ -24,52 +20,15 @@ impl Health {
}
pub(crate) async fn check(&mut self) -> bool {
if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards are answering gRPC calls
self.client.health().await.is_ok()
let value = if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
// Generation is unhealthy or have not sent any generation request yet
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: LIVENESS_ID,
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
// Skips the queue
let value = self.client.prefill(batch).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
self.client.model_health().await
}
.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}

View File

@ -1,9 +1,9 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
TextMessage, Token,
ChatTemplateInputs, ChatTemplateVersions, Entry, FinishReason, GenerateRequest,
GenerateStreamResponse, HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
PrefillToken, Queue, Text, TextMessage, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
@ -15,9 +15,8 @@ use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::{v2, ClientError};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
@ -232,16 +231,8 @@ impl Infer {
while let Some(response) = stream.next().await {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.ids
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
InferStreamResponse::Prefill(prefill_tokens) => {
result_prefill = prefill_tokens;
}
// Push last token
InferStreamResponse::Intermediate { token, top_tokens } => {
@ -792,6 +783,16 @@ fn send_responses(
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs.into_iter())
.zip(prefill_tokens.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
@ -842,7 +843,7 @@ fn send_responses(
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: generated_text.clone(),
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
@ -877,10 +878,36 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
});
}
#[derive(Debug)]
pub(crate) struct GeneratedText {
pub(crate) text: String,
pub(crate) generated_tokens: u32,
pub(crate) finish_reason: FinishReason,
pub(crate) seed: Option<u64>,
}
impl From<v2::GeneratedText> for GeneratedText {
fn from(value: v2::GeneratedText) -> Self {
let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v2_finish_reason {
v2::FinishReason::Length => FinishReason::Length,
v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
v2::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
}
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(Tokens),
Prefill(Vec<PrefillToken>),
// Intermediate messages
Intermediate {
token: Token,
@ -1355,11 +1382,11 @@ mod tests {
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
input: ChatTemplateInputs {
messages: vec![
TextMessage{
TextMessage {
role: "system".to_string(),
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
},
TextMessage{
TextMessage {
role: "user".to_string(),
content: "How many helicopters can a human eat in one sitting?".to_string(),
},

View File

@ -1087,7 +1087,7 @@ pub struct SimpleToken {
stop: usize,
}
#[derive(Serialize, ToSchema)]
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {

View File

@ -12,7 +12,7 @@ use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_client::{ClientError, ShardedClient};
use text_generation_client::{v2::ShardedClient, ClientError};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error;

View File

@ -1,12 +1,17 @@
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use text_generation_client::{Batch, Request};
use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -285,8 +290,12 @@ impl State {
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
)),
stopping_parameters: Some(StoppingCriteriaParameters::from(
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
});
// Set batch_time
@ -355,6 +364,43 @@ enum QueueCommand {
},
}
impl From<ValidParameters> for NextTokenChooserParameters {
fn from(value: ValidParameters) -> Self {
let (grammar, grammar_type) = match value.grammar {
None => (String::new(), GrammarType::None),
Some(grammar) => match grammar {
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
},
};
Self {
temperature: value.temperature,
top_k: value.top_k,
top_p: value.top_p,
typical_p: value.typical_p,
do_sample: value.do_sample,
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
}
}
}
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
fn from(value: ValidStoppingParameters) -> Self {
Self {
max_new_tokens: value.max_new_tokens,
stop_sequences: value.stop_sequences,
ignore_eos_token: value.ignore_eos_token,
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,6 +1,6 @@
use crate::config::Config;
/// HTTP Server logic
use crate::health::Health;
use crate::config::Config;
use crate::health::HealthCheck;
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
@ -34,7 +34,7 @@ use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient};
use text_generation_client::{v2::ShardedClient, ShardInfo};
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
@ -115,7 +115,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)]
#[instrument(skip(health))]
/// Health check method
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
async fn health(
mut health: Extension<HealthCheck>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match health.check().await {
true => Ok(()),
false => Err((
@ -1482,7 +1484,7 @@ pub async fn run(
grammar_support,
);
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
let health_ext = HealthCheck::new(Arc::new(client.clone()), generation_health.clone());
let infer = Infer::new(
client,
validation,
@ -1719,17 +1721,6 @@ async fn shutdown_signal() {
opentelemetry::global::shutdown_tracer_provider();
}
impl From<i32> for FinishReason {
fn from(finish_reason: i32) -> Self {
let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
match finish_reason {
text_generation_client::FinishReason::Length => FinishReason::Length,
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
}
}
}
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {

View File

@ -6,10 +6,6 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{
Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
// use tokenizers::TruncationDirection;
@ -173,10 +169,6 @@ impl Validation {
// Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
input_length = input_length.saturating_sub(max_new_tokens as usize);
// return Err(ValidationError::MaxNewTokens(
// self.max_total_tokens - self.max_input_length,
// max_new_tokens,
// ));
}
Ok((
@ -327,13 +319,13 @@ impl Validation {
// compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type) = match grammar {
let grammar = match grammar {
Some(grammar) => {
// Ensure that grammar is not set if it's not supported
if self.disable_grammar_support {
return Err(ValidationError::Grammar);
}
match grammar {
let valid_grammar = match grammar {
GrammarType::Json(json) => {
let json = match json {
// if value is a string, we need to parse it again to make sure its
@ -350,20 +342,20 @@ impl Validation {
.compile(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
(
// Serialize json to string
// Serialize json to string
ValidGrammar::Json(
serde_json::to_string(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
ProtoGrammarType::Json.into(),
)
}
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
};
Some(valid_grammar)
}
None => (String::new(), ProtoGrammarType::None.into()),
None => None,
};
let parameters = NextTokenChooserParameters {
let parameters = ValidParameters {
temperature,
repetition_penalty,
frequency_penalty,
@ -374,9 +366,8 @@ impl Validation {
seed,
watermark,
grammar,
grammar_type,
};
let stopping_parameters = StoppingCriteriaParameters {
let stopping_parameters = ValidStoppingParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false,
@ -458,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
_ => None,
}
}
fn format_to_mimetype(format: ImageFormat) -> String {
match format {
ImageFormat::Png => "image/png",
@ -636,14 +628,55 @@ type TokenizerRequest = (
Span,
);
#[derive(Debug, Clone)]
pub(crate) enum ValidGrammar {
Json(String),
Regex(String),
}
#[derive(Debug, Clone)]
pub(crate) struct ValidParameters {
/// / exponential scaling output probability distribution
pub temperature: f32,
/// / restricting to the k highest probability elements
pub top_k: u32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
pub top_p: f32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
pub typical_p: f32,
/// / apply sampling on the logits
pub do_sample: bool,
/// / random seed for sampling
pub seed: u64,
/// / repetition penalty
pub repetition_penalty: f32,
/// / frequency penalty
pub frequency_penalty: f32,
/// / token watermarking using "A Watermark for Large Language Models"
pub watermark: bool,
/// / grammar (applied if not empty)
pub grammar: Option<ValidGrammar>,
}
#[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters {
/// / Maximum number of generated tokens
pub max_new_tokens: u32,
/// / Optional stopping sequences
pub stop_sequences: Vec<String>,
/// / Ignore end of sequence token
/// / used for benchmarking
pub ignore_eos_token: bool,
}
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
pub inputs: Vec<InputChunk>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
}