Nicolas Patry 2024-07-30 12:22:24 +02:00
parent ddbbf6b50c
commit bc0a33e1c9
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
17 changed files with 5305 additions and 185 deletions

.gitignore: 2 changes

@@ -3,6 +3,8 @@ target
 router/tokenizer.json
 *__pycache__*
+backends/v3/src/client/pb
+
 # ROCm auto-generated files
 *.hip
 server/exllamav2_kernels/exllamav2_kernels/hip/

Cargo.lock (generated): 4963 changes
File diff suppressed because it is too large


@@ -4,8 +4,9 @@ members = [
 "backends/v3",
 # "backends/client",
 "backends/grpc-metadata",
+# "backends/trtllm",
 "launcher"
-, "backends/trtllm"]
+]
 resolver = "2"
 [workspace.package]

@@ -18,6 +19,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
 base64 = "0.22.0"
 tokenizers = { version = "0.19.1", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
+metrics = { version = "0.23.0" }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 [profile.release]
 incremental = true


@@ -2,8 +2,8 @@ use std::future::Future;
 use std::path::Path;
 use std::pin::{pin, Pin};
 use std::str::FromStr;
-use std::sync::{Arc, OnceLock};
 use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, OnceLock};
 use std::task::{Context, Poll};
 use std::time::Duration;

@@ -13,15 +13,15 @@ use log::{error, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::RwLock;
-use tokio::time::{Instant, sleep};
+use tokio::time::{sleep, Instant};
-use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{instrument, Level, span};
+use tokio_stream::{Stream, StreamExt};
+use tracing::{instrument, span, Level};
-use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
+use text_generation_router::{FinishReason, Token};
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};


@@ -24,8 +24,8 @@ grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
 hf-hub = { workspace = true }
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
-metrics = "0.21.1"
-metrics-exporter-prometheus = { version = "0.12.1", features = [] }
+metrics = { workspace = true }
+metrics-exporter-prometheus = { workspace = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"


@@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
 let batch_size = batch.size;
 let batch_max_tokens = batch.max_tokens;
 let mut batches = vec![batch];
-metrics::gauge!("tgi_batch_current_size", batch_size as f64);
-metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
+metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 let min_size = if waiting_tokens >= max_waiting_tokens {
 // If we didn't onboard any new requests since >= max_waiting_tokens, we try

@@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
 {
 // Tracking metrics
 if min_size.is_some() {
-metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+    .increment(1);
 } else {
-metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+    .increment(1);
 }
 entries.iter_mut().for_each(|(_, entry)| {

@@ -218,8 +220,8 @@ pub(crate) async fn batching_task(
 .await;
 waiting_tokens += 1;
 }
-metrics::gauge!("tgi_batch_current_size", 0.0);
-metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
+metrics::gauge!("tgi_batch_current_size").set(0.0);
+metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
 }
 }
 }
@@ -232,7 +234,7 @@ async fn prefill(
 ) -> Option<CachedBatch> {
 let start_time = Instant::now();
 let batch_id = batch.id;
-metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
+metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
 match client.prefill(batch).await {
 Ok((generations, next_batch, timings)) => {

@@ -243,18 +245,22 @@ async fn prefill(
 // Filter next batch and remove requests that were stopped
 let next_batch = filter_batch(client, next_batch, entries).await;
-metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
-metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
+    .record(timings.forward.as_secs_f64());
+metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+    .record(timings.decode.as_secs_f64());
+metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+    .record(start_filtering_time.elapsed().as_secs_f64());
+metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
+    .record(start_time.elapsed().as_secs_f64());
+metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
 next_batch
 }
 // If we have an error, we discard the whole batch
 Err(err) => {
 let _ = client.clear_cache(Some(batch_id)).await;
 send_errors(err, entries);
-metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
 None
 }
 }
@@ -268,7 +274,7 @@ async fn decode(
 ) -> Option<CachedBatch> {
 let start_time = Instant::now();
 let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
+metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
 match client.decode(batches).await {
 Ok((generations, next_batch, timings)) => {

@@ -280,13 +286,18 @@ async fn decode(
 let next_batch = filter_batch(client, next_batch, entries).await;
 if let Some(concat_duration) = timings.concat {
-metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+    .record(concat_duration.as_secs_f64());
 }
-metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
-metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
+metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+    .record(timings.forward.as_secs_f64());
+metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+    .record(timings.decode.as_secs_f64());
+metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+    .record(start_filtering_time.elapsed().as_secs_f64());
+metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+    .record(start_time.elapsed().as_secs_f64());
+metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
 next_batch
 }
 // If we have an error, we discard the whole batch

@@ -295,7 +306,7 @@ async fn decode(
 let _ = client.clear_cache(Some(id)).await;
 }
 send_errors(err, entries);
-metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
+metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
 None
 }
 }
@@ -353,7 +364,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 // request and we need to stop generating hence why we unwrap_or(true)
 let stopped = send_responses(generation, entry).map_err(|err| {
 tracing::error!("Entry response channel error.");
-metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
 err
 }).unwrap_or(true);
 if stopped {

@@ -369,7 +380,7 @@ fn send_responses(
 ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
 // Return directly if the channel is disconnected
 if entry.response_tx.is_closed() {
-metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
 return Ok(true);
 }

@@ -395,7 +406,7 @@ fn send_responses(
 // Create last Token
 let tokens_ = generation.tokens.expect("Non empty tokens in generation");
 let n = tokens_.ids.len();
-metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
+metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
 let mut iterator = tokens_
 .ids
 .into_iter()

@@ -460,7 +471,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 // Create and enter a span to link this function back to the entry
 let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
 let err = InferError::GenerationError(error.to_string());
-metrics::increment_counter!("tgi_request_failure", "err" => "generation");
+metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
 tracing::error!("{err}");
 // unwrap_or is valid here as we don't care if the receiver is gone.
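The hunks above migrate this backend from the metrics 0.21 macros (value passed inside the macro call) to the handle-based API used by metrics 0.22 and later, which the workspace now pins at 0.23: counter!, gauge! and histogram! return a handle carrying increment, set and record. A minimal sketch of the two styles, assuming only the metrics crate is present (the macros are no-ops until a recorder such as a Prometheus exporter is installed):

// Minimal sketch of the handle-based metrics 0.23 API used in the hunks above.
use std::time::Instant;

fn record_prefill_metrics(batch_size: usize, start_time: Instant) {
    // metrics 0.21 style (removed in this commit):
    //   metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
    //   metrics::gauge!("tgi_batch_current_size", batch_size as f64);

    // metrics 0.23 style: the macro resolves the metric name and labels,
    // then the returned handle records the value.
    metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
    metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
    metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
        .record(start_time.elapsed().as_secs_f64());
}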


@@ -8,10 +8,10 @@ use tonic::Status;
 #[allow(clippy::derive_partial_eq_without_eq)]
 mod pb;

-mod client;
+mod grpc_client;
 mod sharded_client;

-pub use client::Client;
+pub use grpc_client::Client;
 pub use pb::generate::v3::{
 input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
 HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,


@@ -2,7 +2,7 @@ use crate::client::{ClientError, Result};
 /// Multi shard Client
 use crate::client::{Health, ShardInfo};
-use crate::client::client::{DecodeTimings, PrefillTimings};
+use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
 use crate::client::{
 Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
 NextTokenChooserParameters, Request, StoppingCriteriaParameters,


@@ -30,6 +30,7 @@ pub struct BackendInfo {
 pub max_batch_size: Option<usize>,
 }

+#[allow(clippy::too_many_arguments)]
 pub async fn connect_backend(
 max_input_tokens: usize,
 max_total_tokens: usize,


@ -1,4 +1,4 @@
use clap::Parser; use clap::{Parser, Subcommand};
use text_generation_router::server; use text_generation_router::server;
use text_generation_router_v3::{connect_backend, V3Error}; use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error; use thiserror::Error;
@ -7,6 +7,9 @@ use thiserror::Error;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
struct Args { struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
@ -44,6 +47,8 @@ struct Args {
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool, json_output: bool,
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
@ -65,12 +70,18 @@ struct Args {
max_client_batch_size: usize, max_client_batch_size: usize,
} }
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), RouterError> { async fn main() -> Result<(), RouterError> {
// Get args // Get args
let args = Args::parse(); let args = Args::parse();
// Pattern match configuration // Pattern match configuration
let Args { let Args {
command,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
@ -89,6 +100,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_config_path, tokenizer_config_path,
revision, revision,
validation_workers, validation_workers,
api_key,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
otlp_service_name, otlp_service_name,
@ -101,8 +113,19 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size, max_client_batch_size,
} = args; } = args;
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
text_generation_router::logging::init_logging(
otlp_endpoint,
otlp_service_name,
json_output,
);
false
}
};
// Launch Tokio runtime // Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args // Validate args
if max_input_tokens >= max_total_tokens { if max_input_tokens >= max_total_tokens {
@ -151,6 +174,7 @@ async fn main() -> Result<(), RouterError> {
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
validation_workers, validation_workers,
api_key,
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
@ -163,6 +187,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
print_schema_command,
) )
.await?; .await?;
Ok(()) Ok(())
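The new PrintSchema subcommand lets the router emit its OpenAPI schema without starting the server, and logging is only initialized when no subcommand is given. A self-contained sketch of the same clap pattern (field names trimmed for illustration; clap's derive and env features are assumed):

// Sketch of an optional clap subcommand next to ordinary flags, following the
// pattern added to Args in the diff above.
use clap::{Parser, Subcommand};

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
}

#[derive(Debug, Subcommand)]
enum Commands {
    /// Print the OpenAPI schema and exit instead of serving requests.
    PrintSchema,
}

fn main() {
    let args = Args::parse();
    let print_schema_command = matches!(args.command, Some(Commands::PrintSchema));
    println!(
        "print_schema={print_schema_command} max_concurrent_requests={}",
        args.max_concurrent_requests
    );
}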


@@ -17,8 +17,8 @@ futures = "0.3.28"
 hf-hub = { workspace = true }
 itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
-metrics = "0.23.0"
-metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+metrics = { workspace = true }
+metrics-exporter-prometheus = { workspace = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"

@@ -48,6 +48,7 @@ base64 = { workspace = true }
 sysinfo = "0.30.13"
 uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
 csv = "1.3.0"
+ureq = "=2.9"

 [build-dependencies]


@@ -1,5 +1,7 @@
 use crate::infer::InferError;
-use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
+use crate::{
+    ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
+};
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;

@@ -19,8 +21,8 @@ pub(crate) struct ChatTemplate {
 impl ChatTemplate {
 pub(crate) fn new(
 template: String,
-bos_token: Option<String>,
-eos_token: Option<String>,
+bos_token: Option<TokenizerConfigToken>,
+eos_token: Option<TokenizerConfigToken>,
 ) -> Self {
 let mut env = Box::new(Environment::new());
 // enable things like .strip() or .capitalize()

@@ -38,8 +40,8 @@ impl ChatTemplate {
 Self {
 template,
-bos_token,
-eos_token,
+bos_token: bos_token.map(|token| token.as_str().to_string()),
+eos_token: eos_token.map(|token| token.as_str().to_string()),
 use_default_tool_template,
 }
 }

@@ -52,9 +54,9 @@ impl ChatTemplate {
 if self.use_default_tool_template {
 if let Some(last_message) = messages.last_mut() {
 if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-last_message.content.push(MessageChunk::Text(Text {
+last_message.content.push(MessageChunk::Text {
 text: format!("\n---\n{}\n{}", tool_prompt, tools),
-}));
+});
 }
 }
 }
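ChatTemplate::new now receives TokenizerConfigToken values and flattens them to strings with as_str(). The motivation is that tokenizer_config.json may spell bos_token and eos_token either as a bare string or as an AddedToken-like object; the sketch below uses a hypothetical stand-in enum to illustrate that shape (the real type lives in the router crate):

// Hedged sketch: a stand-in for a tokenizer-config token that may be either a
// bare string or an object with a "content" field, as tokenizer_config.json allows.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum ConfigToken {
    String(String),
    Object { content: String },
}

impl ConfigToken {
    fn as_str(&self) -> &str {
        match self {
            ConfigToken::String(s) => s,
            ConfigToken::Object { content } => content,
        }
    }
}

fn main() {
    let bare: ConfigToken = serde_json::from_str(r#""<s>""#).unwrap();
    let object: ConfigToken =
        serde_json::from_str(r#"{"content": "<s>", "lstrip": false}"#).unwrap();
    // Both spellings resolve to the same plain string for the chat template.
    assert_eq!(bare.as_str(), "<s>");
    assert_eq!(object.as_str(), "<s>");
}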


@@ -1,46 +1,62 @@
 use crate::infer::InferError;
-use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
+use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;

 pub(crate) struct ToolGrammar {}

 impl ToolGrammar {
+    // find a tool by name
+    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
+        tools
+            .iter()
+            .find(|tool| tool.function.name == name)
+            .cloned()
+            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
+    }
+
 pub fn apply(
 tools: Option<Vec<Tool>>,
-tool_choice: Option<ToolType>,
+tool_choice: ToolChoice,
 ) -> Result<Option<Tools>, InferError> {
-if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
-// let tool_prompt = tool_prompt.unwrap_or_default();
+// if no tools are provided, we return None
+let tools = match tools {
+    Some(tools) if !tools.is_empty() => tools,
+    _ => return Ok(None),
+};
+
+let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
+
+// if tools are provided and no tool_choice we default to the OneOf
 let tools_to_use = match tool_choice {
 ToolType::FunctionName(name) => {
-vec![req_tools
-.iter()
-.find(|tool| tool.function.name == *name)
-.unwrap_or_else(|| panic!("Tool with name {} not found", name))
-.clone()]
+vec![Self::find_tool_by_name(&tools, &name)?]
 }
-ToolType::OneOf => req_tools.to_owned(),
+ToolType::Function { function } => {
+    vec![Self::find_tool_by_name(&tools, &function.name)?]
+}
+ToolType::OneOf => tools,
+ToolType::NoTool => return Ok(None),
 };

 // adds the error notification function for LLM feedback if required
 let mut text_response_properties = Map::new();
 text_response_properties.insert(
 "error".to_string(),
-json!({
+serde_json::json!({
 "type": "string",
 "description": "The error or issue to notify"
 }),
 );
 text_response_properties.insert(
 "_name".to_string(),
-json!({
+serde_json::json!({
 "type": "string",
 "const": "notify_error"
 }),
 );

-let functions: HashMap<String, Value> = tools_to_use
+let functions: HashMap<String, serde_json::Value> = tools_to_use
 .iter()
 .map(|tool| {
 let func = tool.function.clone();

@@ -91,7 +107,7 @@ impl ToolGrammar {
 })
 .chain([(
 "notify_error".to_string(),
-json!({
+serde_json::json!({
 "properties": text_response_properties,
 "required": ["error", "_name"],
 "type": "object"

@@ -114,9 +130,6 @@ impl ToolGrammar {
 },
 };

-return Ok(Some(tools));
-}
-// Err(InferError::ToolError("No tools provided".to_string()))
-Ok(None)
+Ok(Some(tools))
 }
 }
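The rewritten ToolGrammar::apply replaces the panicking lookup with a fallible find_tool_by_name and handles each ToolType variant explicitly (a named function, one-of, or no tool at all). A simplified, self-contained sketch of that resolution logic, using trimmed stand-ins for the router's Tool and ToolType types:

// Simplified sketch of the tool-choice resolution introduced above. `Tool` and
// `ToolType` are trimmed stand-ins; the real definitions live in the router crate.
#[derive(Clone, Debug)]
struct Tool {
    name: String,
}

enum ToolType {
    FunctionName(String),
    OneOf,
    NoTool,
}

// Fallible lookup replacing the old `unwrap_or_else(|| panic!(...))`.
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, String> {
    tools
        .iter()
        .find(|tool| tool.name == name)
        .cloned()
        .ok_or_else(|| format!("Tool with name {} not found", name))
}

fn resolve_tools(tools: Vec<Tool>, tool_choice: ToolType) -> Result<Option<Vec<Tool>>, String> {
    // No tools at all: nothing to constrain generation with.
    if tools.is_empty() {
        return Ok(None);
    }
    let selected = match tool_choice {
        // A specific function was requested: error out if it is unknown.
        ToolType::FunctionName(name) => vec![find_tool_by_name(&tools, &name)?],
        // No explicit choice: let the model pick among all provided tools.
        ToolType::OneOf => tools,
        // Tools explicitly disabled by the request.
        ToolType::NoTool => return Ok(None),
    };
    Ok(Some(selected))
}

fn main() {
    let tools = vec![Tool { name: "get_weather".to_string() }];
    let chosen = resolve_tools(tools, ToolType::FunctionName("get_weather".to_string()));
    println!("{chosen:?}");
}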


@@ -35,9 +35,9 @@ use futures::stream::StreamExt;
 use futures::stream::{FuturesOrdered, FuturesUnordered};
 use futures::Stream;
 use futures::TryStreamExt;
-use http::header::AUTHORIZATION;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
 use hf_hub::{Cache, Repo, RepoType};
+use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use serde_json::Value;
 use std::convert::Infallible;

@@ -46,6 +46,7 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use thiserror::Error;
+use tokenizers::processors::template::TemplateProcessing;
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -1406,15 +1407,16 @@ pub async fn run(
 max_input_tokens: usize,
 max_total_tokens: usize,
 validation_workers: usize,
-addr: SocketAddr,
-allow_origin: Option<AllowOrigin>,
 api_key: Option<String>,
+tokenizer_name: String,
+tokenizer_config_path: Option<String>,
+revision: Option<String>,
+hostname: String,
+port: u16,
+cors_allow_origin: Option<Vec<String>>,
 ngrok: bool,
 _ngrok_authtoken: Option<String>,
 _ngrok_edge: Option<String>,
-tokenizer_config: HubTokenizerConfig,
-preprocessor_config: Option<HubPreprocessorConfig>,
-processor_config: HubProcessorConfig,
 messages_api_enabled: bool,
 grammar_support: bool,
 max_client_batch_size: usize,

@@ -1507,6 +1509,16 @@ pub async fn run(
 println!("{}", api_doc);
 std::process::exit(0);
 }
+// CORS allowed origins
+// map to go inside the option and then map to parse from String to HeaderValue
+// Finally, convert to AllowOrigin
+let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
+    AllowOrigin::list(
+        cors_allow_origin
+            .iter()
+            .map(|origin| origin.parse::<HeaderValue>().unwrap()),
+    )
+});

 // Parse Huggingface hub token
 let authorization_token = std::env::var("HF_TOKEN")
@@ -1564,6 +1576,7 @@ pub async fn run(
 tokenizer_filename,
 config_filename,
 tokenizer_config_filename,
+preprocessor_config_filename,
 processor_config_filename,
 model_info,
 ) = match api {

@@ -1571,6 +1584,7 @@ pub async fn run(
 Some(local_path.join("tokenizer.json")),
 Some(local_path.join("config.json")),
 Some(local_path.join("tokenizer_config.json")),
+Some(local_path.join("preprocessor_config.json")),
 Some(local_path.join("processor_config.json")),
 None,
 ),

@@ -1587,6 +1601,7 @@ pub async fn run(
 };
 let config_filename = api_repo.get("config.json").await.ok();
 let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
 let processor_config_filename = api_repo.get("processor_config.json").await.ok();

 let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {

@@ -1599,6 +1614,7 @@ pub async fn run(
 tokenizer_filename,
 config_filename,
 tokenizer_config_filename,
+preprocessor_config_filename,
 processor_config_filename,
 model_info,
 )
@@ -1613,13 +1629,40 @@ pub async fn run(
 repo.get("tokenizer.json"),
 repo.get("config.json"),
 repo.get("tokenizer_config.json"),
+repo.get("preprocessor_config.json"),
 repo.get("processor_config.json"),
 None,
 )
 }
 };

-let tokenizer: Option<Tokenizer> =
-tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
+// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+{
+    HubTokenizerConfig::from_file(filename)
+} else {
+    tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+};
+let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+    tracing::warn!("Could not find tokenizer config locally and no API specified");
+    HubTokenizerConfig::default()
+});
+
+let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+    let mut tokenizer = Tokenizer::from_file(filename).ok();
+    if let Some(tokenizer) = &mut tokenizer {
+        if let Some(class) = &tokenizer_config.tokenizer_class {
+            if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
+                if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
+                    tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                    tokenizer.with_post_processor(post_processor);
+                }
+            }
+        }
+    }
+    tokenizer
+});

 let config: Option<Config> = config_filename.and_then(|filename| {
 std::fs::read_to_string(filename)
 .ok()
@@ -1638,22 +1681,13 @@ pub async fn run(
 pipeline_tag: None,
 });

-// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-{
-HubTokenizerConfig::from_file(filename)
-} else {
-tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-};
-let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-tracing::warn!("Could not find tokenizer config locally and no API specified");
-HubTokenizerConfig::default()
-});
 let processor_config = processor_config_filename
 .and_then(HubProcessorConfig::from_file)
 .unwrap_or_default();
+let preprocessor_config: Option<HubPreprocessorConfig> =
+    preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);

 tracing::info!("Using config {config:?}");
 if tokenizer.is_none() {
 tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
@@ -2107,3 +2141,74 @@ pub enum WebServerError {
 #[error("Axum error: {0}")]
 Axum(#[from] axum::BoxError),
 }
+
+/// Create a post_processor for the LlamaTokenizer
+fn create_post_processor(
+    tokenizer: &Tokenizer,
+    tokenizer_config: &HubTokenizerConfig,
+) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
+    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
+    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
+
+    let bos_token = tokenizer_config.bos_token.as_ref();
+    let eos_token = tokenizer_config.eos_token.as_ref();
+
+    if add_bos_token && bos_token.is_none() {
+        panic!("add_bos_token = true but bos_token is None");
+    }
+
+    if add_eos_token && eos_token.is_none() {
+        panic!("add_eos_token = true but eos_token is None");
+    }
+
+    let mut single = Vec::new();
+    let mut pair = Vec::new();
+    let mut special_tokens = Vec::new();
+
+    if add_bos_token {
+        if let Some(bos) = bos_token {
+            let bos_token_id = tokenizer
+                .token_to_id(bos.as_str())
+                .expect("Should have found the bos token id");
+            special_tokens.push((bos.as_str(), bos_token_id));
+            single.push(format!("{}:0", bos.as_str()));
+            pair.push(format!("{}:0", bos.as_str()));
+        }
+    }
+
+    single.push("$A:0".to_string());
+    pair.push("$A:0".to_string());
+
+    if add_eos_token {
+        if let Some(eos) = eos_token {
+            let eos_token_id = tokenizer
+                .token_to_id(eos.as_str())
+                .expect("Should have found the eos token id");
+            special_tokens.push((eos.as_str(), eos_token_id));
+            single.push(format!("{}:0", eos.as_str()));
+            pair.push(format!("{}:0", eos.as_str()));
+        }
+    }
+
+    if add_bos_token {
+        if let Some(bos) = bos_token {
+            pair.push(format!("{}:1", bos.as_str()));
+        }
+    }
+
+    pair.push("$B:1".to_string());
+
+    if add_eos_token {
+        if let Some(eos) = eos_token {
+            pair.push(format!("{}:1", eos.as_str()));
+        }
+    }
+
+    let post_processor = TemplateProcessing::builder()
+        .try_single(single)?
+        .try_pair(pair)?
+        .special_tokens(special_tokens)
+        .build()?;
+
+    Ok(post_processor)
+}
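For a Llama-style tokenizer_config.json with add_bos_token = true and add_eos_token = false, the create_post_processor added above reduces to a template like the hedged sketch below; the "<s>" literal and token id 1 are illustrative assumptions rather than values read from a real config:

// Hedged sketch: the TemplateProcessing that create_post_processor builds when
// add_bos_token = true and add_eos_token = false. "<s>" and id 1 are assumptions.
use tokenizers::processors::template::TemplateProcessing;

fn llama_like_post_processor(
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let post_processor = TemplateProcessing::builder()
        // single sequence: "<s> $A"
        .try_single(vec!["<s>:0".to_string(), "$A:0".to_string()])?
        // pair of sequences: "<s> $A <s> $B" (the bos token is re-inserted
        // before the second sequence, mirroring the transformers override)
        .try_pair(vec![
            "<s>:0".to_string(),
            "$A:0".to_string(),
            "<s>:1".to_string(),
            "$B:1".to_string(),
        ])?
        .special_tokens(vec![("<s>", 1)])
        .build()?;
    Ok(post_processor)
}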


@@ -5,13 +5,12 @@ use crate::{
 GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
 };
 use base64::{engine::general_purpose::STANDARD, Engine};
-use image::{io::Reader as ImageReader, ImageFormat};
+use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
 use std::iter;
-use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::mpsc;

@@ -181,11 +180,7 @@ impl Validation {
 input_length = input_length.saturating_sub(max_new_tokens as usize);
 }

-Ok((
-vec![Chunk::Text(inputs).into()],
-input_length,
-max_new_tokens,
-))
+Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
 }
 }
@@ -589,7 +584,7 @@ fn prepare_input(
 tokenizer: &Tokenizer,
 config: Option<&Config>,
 preprocessor_config: Option<&HubPreprocessorConfig>,
-) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
+) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
 use Config::*;
 static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
 let (tokenizer_query, input_chunks) = match config {

@@ -601,16 +596,16 @@ fn prepare_input(
 let chunk_start = chunk.start();
 let chunk_end = chunk.end();
 if chunk_start != start {
-input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
+input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
 tokenizer_query.push_str(&inputs[start..chunk_start]);
 }
 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
+input_chunks.push(Chunk::Image(Image { data, mimetype }));
 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
 start = chunk_end;
 }
 if start != inputs.len() {
-input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
+input_chunks.push(Chunk::Text(inputs[start..].to_string()));
 tokenizer_query.push_str(&inputs[start..]);
 }

@@ -618,7 +613,7 @@ fn prepare_input(
 (tokenizer_query, input_chunks)
 }
-_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
+_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
 };

 // Get the number of tokens in the input

@@ -784,8 +779,7 @@ pub enum ValidationError {
 #[error("Could not fetch image: {0}")]
 FailedFetchImage(#[from] reqwest::Error),
 #[error("{0} modality is not supported")]
-UnsupportedModality(&'static str)
+UnsupportedModality(&'static str),
 }

 #[cfg(test)]
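prepare_input now returns the router's own Chunk values instead of converting each one into the client's InputChunk. A trimmed, runnable sketch of the splitting loop it implements, with a simplified Chunk enum standing in for the real types (the regex is the one from the hunk above; the real code fetches and stores the image bytes rather than the markdown reference):

// Trimmed sketch of the image-markdown splitting done in prepare_input.
use regex::Regex;

#[derive(Debug, PartialEq)]
enum Chunk {
    Text(String),
    Image(String), // simplified: keeps the markdown image reference itself
}

fn split_chunks(inputs: &str) -> Vec<Chunk> {
    // Same pattern as the router: matches markdown images of the form ![](...)
    let re = Regex::new(r"!\[\]\([^\)]*\)").unwrap();
    let mut chunks = Vec::new();
    let mut start = 0;
    for m in re.find_iter(inputs) {
        if m.start() != start {
            chunks.push(Chunk::Text(inputs[start..m.start()].to_string()));
        }
        chunks.push(Chunk::Image(inputs[m.start()..m.end()].to_string()));
        start = m.end();
    }
    if start != inputs.len() {
        chunks.push(Chunk::Text(inputs[start..].to_string()));
    }
    chunks
}

fn main() {
    let chunks = split_chunks("Describe this image: ![](https://example.com/cat.png) please.");
    // -> [Text("Describe this image: "), Image("![](https://example.com/cat.png)"), Text(" please.")]
    println!("{chunks:?}");
}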