mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
small refactor to make router a bit more agnostic
This commit is contained in:
parent
df71aafdcc
commit
4dbb342fe3
236
proto/v3/generate.proto
Normal file
236
proto/v3/generate.proto
Normal file
@ -0,0 +1,236 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package generate.v3;
|
||||
|
||||
service TextGenerationService {
|
||||
/// Model Info
|
||||
rpc Info (InfoRequest) returns (InfoResponse) {}
|
||||
/// Service discovery
|
||||
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
||||
/// Empties batch cache
|
||||
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
||||
/// Remove requests from a cached batch
|
||||
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
|
||||
/// Warmup the model and compute max cache size
|
||||
rpc Warmup (WarmupRequest) returns (WarmupResponse);
|
||||
/// Prefill batch and decode first token
|
||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||
/// Health check
|
||||
rpc Health (HealthRequest) returns (HealthResponse);
|
||||
}
|
||||
|
||||
message HealthRequest {}
|
||||
message HealthResponse {}
|
||||
|
||||
/// Empty request
|
||||
message InfoRequest {}
|
||||
|
||||
message InfoResponse {
|
||||
bool requires_padding = 1;
|
||||
string dtype = 2;
|
||||
string device_type = 3;
|
||||
optional uint32 window_size = 4;
|
||||
uint32 speculate = 5;
|
||||
}
|
||||
|
||||
/// Empty request
|
||||
message ServiceDiscoveryRequest {}
|
||||
|
||||
message ServiceDiscoveryResponse {
|
||||
/// Other shards urls
|
||||
repeated string urls = 1;
|
||||
}
|
||||
|
||||
message ClearCacheRequest {
|
||||
/// Optional batch id
|
||||
optional uint64 id = 1;
|
||||
}
|
||||
|
||||
/// Empty response
|
||||
message ClearCacheResponse {}
|
||||
|
||||
enum GrammarType {
|
||||
GRAMMAR_TYPE_NONE = 0;
|
||||
GRAMMAR_TYPE_JSON = 1;
|
||||
GRAMMAR_TYPE_REGEX = 2;
|
||||
}
|
||||
|
||||
message NextTokenChooserParameters {
|
||||
/// exponential scaling output probability distribution
|
||||
float temperature = 1;
|
||||
/// restricting to the k highest probability elements
|
||||
uint32 top_k = 2;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float top_p = 3;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float typical_p = 4;
|
||||
/// apply sampling on the logits
|
||||
bool do_sample = 5;
|
||||
/// random seed for sampling
|
||||
uint64 seed = 6;
|
||||
/// repetition penalty
|
||||
float repetition_penalty = 7;
|
||||
/// frequency penalty
|
||||
float frequency_penalty = 9;
|
||||
/// token watermarking using "A Watermark for Large Language Models"
|
||||
bool watermark = 8;
|
||||
/// grammar (applied if not empty)
|
||||
string grammar = 10;
|
||||
/// grammar type
|
||||
GrammarType grammar_type = 11;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
/// Maximum number of generated tokens
|
||||
uint32 max_new_tokens = 1;
|
||||
/// Optional stopping sequences
|
||||
repeated string stop_sequences = 2;
|
||||
/// Ignore end of sequence token
|
||||
/// used for benchmarking
|
||||
bool ignore_eos_token = 3;
|
||||
}
|
||||
|
||||
message Request {
|
||||
/// Request ID
|
||||
uint64 id = 1;
|
||||
/// The generation context
|
||||
string inputs = 2;
|
||||
/// Context truncation
|
||||
uint32 truncate = 3;
|
||||
/// Next Token Chooser Parameters
|
||||
NextTokenChooserParameters parameters = 4;
|
||||
/// Stopping Criteria Parameters
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
/// Return prefill logprobs
|
||||
bool prefill_logprobs = 6;
|
||||
/// Return most likely n tokens
|
||||
uint32 top_n_tokens = 7;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
/// Batch ID
|
||||
uint64 id = 1;
|
||||
/// Individual requests
|
||||
repeated Request requests = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
}
|
||||
|
||||
message CachedBatch {
|
||||
/// Batch ID
|
||||
uint64 id = 1;
|
||||
/// Individual requests ids
|
||||
repeated uint64 request_ids = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
}
|
||||
|
||||
enum FinishReason {
|
||||
FINISH_REASON_LENGTH = 0;
|
||||
FINISH_REASON_EOS_TOKEN = 1;
|
||||
FINISH_REASON_STOP_SEQUENCE = 2;
|
||||
}
|
||||
|
||||
message GeneratedText {
|
||||
/// Output
|
||||
string text = 1;
|
||||
/// Number of generated tokens
|
||||
uint32 generated_tokens = 2;
|
||||
/// Finish reason
|
||||
FinishReason finish_reason = 3;
|
||||
/// Seed
|
||||
optional uint64 seed = 4;
|
||||
}
|
||||
|
||||
message Tokens {
|
||||
/// Token IDs
|
||||
repeated uint32 ids = 1;
|
||||
/// Logprobs
|
||||
repeated float logprobs = 2;
|
||||
/// tokens
|
||||
repeated string texts = 3;
|
||||
/// special
|
||||
repeated bool is_special = 4;
|
||||
}
|
||||
|
||||
message Generation {
|
||||
/// Request ID
|
||||
uint64 request_id = 1;
|
||||
/// Prefill tokens (optional)
|
||||
Tokens prefill_tokens = 2;
|
||||
Tokens tokens = 3;
|
||||
/// Complete generated text
|
||||
optional GeneratedText generated_text = 4;
|
||||
/// Top tokens
|
||||
repeated Tokens top_tokens = 5;
|
||||
}
|
||||
|
||||
message FilterBatchRequest {
|
||||
/// Batch ID
|
||||
uint64 batch_id = 1;
|
||||
/// Requests to keep
|
||||
repeated uint64 request_ids = 2;
|
||||
}
|
||||
|
||||
message FilterBatchResponse {
|
||||
/// Filtered Batch (cached)
|
||||
CachedBatch batch = 1;
|
||||
}
|
||||
|
||||
|
||||
message PrefillRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
}
|
||||
|
||||
message PrefillResponse {
|
||||
/// Generation
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional CachedBatch batch = 2;
|
||||
/// Forward elapsed time in nanoseconds
|
||||
uint64 forward_ns = 3;
|
||||
/// Decode elapsed time in nanoseconds
|
||||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
}
|
||||
|
||||
message DecodeRequest {
|
||||
/// Cached batches
|
||||
repeated CachedBatch batches = 1;
|
||||
}
|
||||
|
||||
message DecodeResponse {
|
||||
/// Decodes
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional CachedBatch batch = 2;
|
||||
/// Forward elapsed time in nanoseconds
|
||||
uint64 forward_ns = 3;
|
||||
/// Decode elapsed time in nanoseconds
|
||||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
/// Concatenate elapsed time in nanoseconds
|
||||
optional uint64 concat_ns = 6;
|
||||
}
|
||||
|
||||
message WarmupRequest {
|
||||
/// Batch to warmup on
|
||||
Batch batch = 1;
|
||||
uint32 max_input_length = 2;
|
||||
uint32 max_prefill_tokens = 3;
|
||||
uint32 max_total_tokens = 4;
|
||||
}
|
||||
|
||||
message WarmupResponse {
|
||||
/// Maximum number of tokens supported by the model
|
||||
optional uint32 max_supported_total_tokens = 1;
|
||||
}
|
@ -6,6 +6,7 @@ authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "^0.1"
|
||||
base64 = { workspace = true }
|
||||
futures = "^0.3"
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
|
@ -1,19 +1,31 @@
|
||||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/generate.proto");
|
||||
fs::create_dir("src/pb").unwrap_or(());
|
||||
println!("cargo:rerun-if-changed=../../proto/**");
|
||||
|
||||
fs::create_dir_all("src/v2/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/pb")
|
||||
.out_dir("src/v2/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
fs::create_dir_all("src/v3/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v3/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,25 +1,32 @@
|
||||
//! Text Generation gRPC client library
|
||||
|
||||
mod client;
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
mod sharded_client;
|
||||
pub mod v2;
|
||||
pub mod v3;
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v2::input_chunk::Chunk;
|
||||
pub use pb::generate::v2::HealthResponse;
|
||||
pub use pb::generate::v2::Image;
|
||||
pub use pb::generate::v2::InfoResponse as ShardInfo;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
|
1
router/client/src/pb/.gitignore
vendored
1
router/client/src/pb/.gitignore
vendored
@ -1 +0,0 @@
|
||||
*.rs
|
13
router/client/src/v2/mod.rs
Normal file
13
router/client/src/v2/mod.rs
Normal file
@ -0,0 +1,13 @@
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v2::HealthResponse;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
1
router/client/src/v2/pb/.gitignore
vendored
Normal file
1
router/client/src/v2/pb/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*
|
13
router/client/src/v3/mod.rs
Normal file
13
router/client/src/v3/mod.rs
Normal file
@ -0,0 +1,13 @@
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v3::HealthResponse;
|
||||
pub use pb::generate::v3::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
1
router/client/src/v3/pb/.gitignore
vendored
Normal file
1
router/client/src/v3/pb/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*
|
@ -1,22 +1,18 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||
};
|
||||
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
|
||||
use text_generation_client::Health;
|
||||
|
||||
// Note: Request ids and batch ids cannot collide.
|
||||
const LIVENESS_ID: u64 = u64::MAX;
|
||||
const BATCH_ID: u64 = u64::MAX;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct Health {
|
||||
client: ShardedClient,
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct HealthCheck {
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Health {
|
||||
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
|
||||
impl HealthCheck {
|
||||
pub(crate) fn new(
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
generation_health,
|
||||
@ -24,52 +20,15 @@ impl Health {
|
||||
}
|
||||
|
||||
pub(crate) async fn check(&mut self) -> bool {
|
||||
if self.generation_health.load(Ordering::SeqCst) {
|
||||
// Generation is healthy, we only check that the shards are answering gRPC calls
|
||||
self.client.health().await.is_ok()
|
||||
let value = if self.generation_health.load(Ordering::SeqCst) {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
// Generation is unhealthy or have not sent any generation request yet
|
||||
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: LIVENESS_ID,
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: ProtoGrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: BATCH_ID,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
};
|
||||
// Skips the queue
|
||||
let value = self.client.prefill(batch).await.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
/// Batching and inference logic
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||
HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
|
||||
TextMessage, Token,
|
||||
ChatTemplateInputs, ChatTemplateVersions, Entry, FinishReason, GenerateRequest,
|
||||
GenerateStreamResponse, HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
|
||||
PrefillToken, Queue, Text, TextMessage, Token,
|
||||
};
|
||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
@ -15,9 +15,8 @@ use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::{
|
||||
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
|
||||
};
|
||||
use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
|
||||
use text_generation_client::{v2, ClientError};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||
@ -232,16 +231,8 @@ impl Infer {
|
||||
while let Some(response) = stream.next().await {
|
||||
match response? {
|
||||
// Add prefill tokens
|
||||
InferStreamResponse::Prefill(tokens) => {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
result_prefill = tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens.logprobs.into_iter())
|
||||
.zip(tokens.texts.into_iter())
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
InferStreamResponse::Prefill(prefill_tokens) => {
|
||||
result_prefill = prefill_tokens;
|
||||
}
|
||||
// Push last token
|
||||
InferStreamResponse::Intermediate { token, top_tokens } => {
|
||||
@ -792,6 +783,16 @@ fn send_responses(
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs.into_iter())
|
||||
.zip(prefill_tokens.texts.into_iter())
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
@ -842,7 +843,7 @@ fn send_responses(
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: generated_text.clone(),
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
@ -877,10 +878,36 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct GeneratedText {
|
||||
pub(crate) text: String,
|
||||
pub(crate) generated_tokens: u32,
|
||||
pub(crate) finish_reason: FinishReason,
|
||||
pub(crate) seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl From<v2::GeneratedText> for GeneratedText {
|
||||
fn from(value: v2::GeneratedText) -> Self {
|
||||
let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v2_finish_reason {
|
||||
v2::FinishReason::Length => FinishReason::Length,
|
||||
v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
v2::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum InferStreamResponse {
|
||||
// Optional first message
|
||||
Prefill(Tokens),
|
||||
Prefill(Vec<PrefillToken>),
|
||||
// Intermediate messages
|
||||
Intermediate {
|
||||
token: Token,
|
||||
@ -1355,11 +1382,11 @@ mod tests {
|
||||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
|
||||
input: ChatTemplateInputs {
|
||||
messages: vec![
|
||||
TextMessage{
|
||||
TextMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
|
||||
},
|
||||
TextMessage{
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "How many helicopters can a human eat in one sitting?".to_string(),
|
||||
},
|
||||
|
@ -1087,7 +1087,7 @@ pub struct SimpleToken {
|
||||
stop: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[serde(rename_all(serialize = "snake_case"))]
|
||||
#[schema(example = "Length")]
|
||||
pub(crate) enum FinishReason {
|
||||
|
@ -12,7 +12,7 @@ use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_client::{v2::ShardedClient, ClientError};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
|
@ -1,12 +1,17 @@
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::ChunksToString;
|
||||
use text_generation_client::Input;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use text_generation_client::v2::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
@ -285,8 +290,12 @@ impl State {
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters::from(
|
||||
entry.request.stopping_parameters.clone(),
|
||||
)),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
});
|
||||
// Set batch_time
|
||||
@ -355,6 +364,43 @@ enum QueueCommand {
|
||||
},
|
||||
}
|
||||
|
||||
impl From<ValidParameters> for NextTokenChooserParameters {
|
||||
fn from(value: ValidParameters) -> Self {
|
||||
let (grammar, grammar_type) = match value.grammar {
|
||||
None => (String::new(), GrammarType::None),
|
||||
|
||||
Some(grammar) => match grammar {
|
||||
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
|
||||
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
|
||||
},
|
||||
};
|
||||
|
||||
Self {
|
||||
temperature: value.temperature,
|
||||
top_k: value.top_k,
|
||||
top_p: value.top_p,
|
||||
typical_p: value.typical_p,
|
||||
do_sample: value.do_sample,
|
||||
seed: value.seed,
|
||||
repetition_penalty: value.repetition_penalty,
|
||||
frequency_penalty: value.frequency_penalty,
|
||||
watermark: value.watermark,
|
||||
grammar,
|
||||
grammar_type: grammar_type.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||
fn from(value: ValidStoppingParameters) -> Self {
|
||||
Self {
|
||||
max_new_tokens: value.max_new_tokens,
|
||||
stop_sequences: value.stop_sequences,
|
||||
ignore_eos_token: value.ignore_eos_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::config::Config;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::config::Config;
|
||||
use crate::health::HealthCheck;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
@ -34,7 +34,7 @@ use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{ShardInfo, ShardedClient};
|
||||
use text_generation_client::{v2::ShardedClient, ShardInfo};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
use tokio::signal;
|
||||
@ -115,7 +115,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||
)]
|
||||
#[instrument(skip(health))]
|
||||
/// Health check method
|
||||
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
async fn health(
|
||||
mut health: Extension<HealthCheck>,
|
||||
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match health.check().await {
|
||||
true => Ok(()),
|
||||
false => Err((
|
||||
@ -1482,7 +1484,7 @@ pub async fn run(
|
||||
grammar_support,
|
||||
);
|
||||
let generation_health = Arc::new(AtomicBool::new(false));
|
||||
let health_ext = Health::new(client.clone(), generation_health.clone());
|
||||
let health_ext = HealthCheck::new(Arc::new(client.clone()), generation_health.clone());
|
||||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
@ -1719,17 +1721,6 @@ async fn shutdown_signal() {
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
}
|
||||
|
||||
impl From<i32> for FinishReason {
|
||||
fn from(finish_reason: i32) -> Self {
|
||||
let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
|
||||
match finish_reason {
|
||||
text_generation_client::FinishReason::Length => FinishReason::Length,
|
||||
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to Axum supported formats
|
||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||
fn from(err: InferError) -> Self {
|
||||
|
@ -6,10 +6,6 @@ use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use text_generation_client::{
|
||||
Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
// use tokenizers::TruncationDirection;
|
||||
@ -173,10 +169,6 @@ impl Validation {
|
||||
// Validate MaxNewTokens
|
||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||
// return Err(ValidationError::MaxNewTokens(
|
||||
// self.max_total_tokens - self.max_input_length,
|
||||
// max_new_tokens,
|
||||
// ));
|
||||
}
|
||||
|
||||
Ok((
|
||||
@ -327,13 +319,13 @@ impl Validation {
|
||||
// compiler and use that to build the FSM here.
|
||||
|
||||
// Validate grammar and unpack the grammar and type for the proto message
|
||||
let (grammar, grammar_type) = match grammar {
|
||||
let grammar = match grammar {
|
||||
Some(grammar) => {
|
||||
// Ensure that grammar is not set if it's not supported
|
||||
if self.disable_grammar_support {
|
||||
return Err(ValidationError::Grammar);
|
||||
}
|
||||
match grammar {
|
||||
let valid_grammar = match grammar {
|
||||
GrammarType::Json(json) => {
|
||||
let json = match json {
|
||||
// if value is a string, we need to parse it again to make sure its
|
||||
@ -350,20 +342,20 @@ impl Validation {
|
||||
.compile(&json)
|
||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
||||
|
||||
(
|
||||
// Serialize json to string
|
||||
// Serialize json to string
|
||||
ValidGrammar::Json(
|
||||
serde_json::to_string(&json)
|
||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
|
||||
ProtoGrammarType::Json.into(),
|
||||
)
|
||||
}
|
||||
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
|
||||
}
|
||||
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
|
||||
};
|
||||
Some(valid_grammar)
|
||||
}
|
||||
None => (String::new(), ProtoGrammarType::None.into()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let parameters = NextTokenChooserParameters {
|
||||
let parameters = ValidParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
@ -374,9 +366,8 @@ impl Validation {
|
||||
seed,
|
||||
watermark,
|
||||
grammar,
|
||||
grammar_type,
|
||||
};
|
||||
let stopping_parameters = StoppingCriteriaParameters {
|
||||
let stopping_parameters = ValidStoppingParameters {
|
||||
max_new_tokens,
|
||||
stop_sequences,
|
||||
ignore_eos_token: false,
|
||||
@ -458,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn format_to_mimetype(format: ImageFormat) -> String {
|
||||
match format {
|
||||
ImageFormat::Png => "image/png",
|
||||
@ -636,14 +628,55 @@ type TokenizerRequest = (
|
||||
Span,
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum ValidGrammar {
|
||||
Json(String),
|
||||
Regex(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidParameters {
|
||||
/// / exponential scaling output probability distribution
|
||||
pub temperature: f32,
|
||||
/// / restricting to the k highest probability elements
|
||||
pub top_k: u32,
|
||||
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
pub top_p: f32,
|
||||
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
pub typical_p: f32,
|
||||
/// / apply sampling on the logits
|
||||
pub do_sample: bool,
|
||||
/// / random seed for sampling
|
||||
pub seed: u64,
|
||||
/// / repetition penalty
|
||||
pub repetition_penalty: f32,
|
||||
/// / frequency penalty
|
||||
pub frequency_penalty: f32,
|
||||
/// / token watermarking using "A Watermark for Large Language Models"
|
||||
pub watermark: bool,
|
||||
/// / grammar (applied if not empty)
|
||||
pub grammar: Option<ValidGrammar>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidStoppingParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
pub max_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
pub stop_sequences: Vec<String>,
|
||||
/// / Ignore end of sequence token
|
||||
/// / used for benchmarking
|
||||
pub ignore_eos_token: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: Vec<InputChunk>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
pub parameters: NextTokenChooserParameters,
|
||||
pub stopping_parameters: StoppingCriteriaParameters,
|
||||
pub parameters: ValidParameters,
|
||||
pub stopping_parameters: ValidStoppingParameters,
|
||||
pub top_n_tokens: u32,
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user