Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

Commit dc07ad2691 ("continue refactoring")
Parent: ba59da1589
@@ -1,5 +1,599 @@
 mod health;
 pub(crate) mod v2;
-pub(crate) mod v3;
+// pub(crate) mod v3;

 pub(crate) use health::HealthCheck;
+
+use crate::validation::{Validation, ValidationError, ValidGenerateRequest};
+use crate::{
+    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest,
+    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
+    PrefillToken, Text, TextMessage, Token,
+};
+use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
+use futures::future::try_join_all;
+use minijinja::{Environment, ErrorKind, Template};
+use serde_json::{json, Map, Value};
+use std::collections::HashMap;
+use std::sync::{
+    atomic::{AtomicBool},
+    Arc,
+};
+use text_generation_client::v2::{ShardedClient};
+use thiserror::Error;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
+use tracing::{instrument, Span};
+
+/// Queue entry
+#[derive(Debug)]
+pub(crate) struct Entry {
+    /// Request
+    pub request: ValidGenerateRequest,
+    /// Response sender to communicate between the Infer struct and the batching_task
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
+    /// Span that will live as long as entry
+    pub span: Span,
+    /// Temporary span used as a guard when logging inference, wait times...
+    pub temp_span: Option<Span>,
+    /// Instant when this entry was queued
+    pub queue_time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
+}
+
+pub(crate) trait InferQueue {
+    /// Append an entry to the queue
+    #[instrument(skip_all)]
+    fn append(&self, entry: Entry);
+}
+
+/// Inference struct
+#[derive(Clone)]
+pub struct Infer {
+    /// Validation
+    validation: Validation,
+    /// Request queue
+    queue: Arc<dyn InferQueue + Send + Sync>,
+    /// Notify batcher on queue appends
+    batching_task_notifier: Arc<Notify>,
+    /// Chat template
+    chat_template: Option<ChatTemplate>,
+    /// Inference limit
+    limit_concurrent_requests: Arc<Semaphore>,
+}
+
+impl Infer {
+    #[allow(clippy::too_many_arguments)]
+    pub(crate) fn new(
+        client: ShardedClient,
+        validation: Validation,
+        waiting_served_ratio: f32,
+        max_batch_prefill_tokens: u32,
+        max_batch_total_tokens: u32,
+        max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
+        max_concurrent_requests: usize,
+        requires_padding: bool,
+        window_size: Option<u32>,
+        speculate: u32,
+        generation_health: Arc<AtomicBool>,
+        tokenizer_config: HubTokenizerConfig,
+        processor_config: HubProcessorConfig,
+    ) -> Self {
+        let queue = v2::Queue::new(requires_padding, 16, window_size, speculate);
+        let batching_task_notifier = Arc::new(Notify::new());
+
+        // Spawn batching background task that contains all the inference logic
+        tokio::spawn(v2::batching_task(
+            client,
+            waiting_served_ratio,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_batch_size,
+            queue.clone(),
+            batching_task_notifier.clone(),
+            generation_health,
+        ));
+
+        let chat_template = tokenizer_config
+            .chat_template
+            .or(processor_config.chat_template)
+            .and_then(|t| match t {
+                ChatTemplateVersions::Single(template) => Some(template),
+                ChatTemplateVersions::Multiple(templates) => templates
+                    .into_iter()
+                    .find(|t| t.name == "default")
+                    .map(|t| t.template),
+            })
+            .map(|t| {
+                // .strip() is not supported in minijinja
+                // .capitalize() is not supported in minijinja but we can use | capitalize
+                let t = t
+                    .replace(".strip()", " | trim")
+                    .replace(".capitalize()", " | capitalize");
+                ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+            });
+
+        // Inference limit with a semaphore
+        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
+
+        Self {
+            validation,
+            queue: Arc::new(queue),
+            batching_task_notifier,
+            chat_template,
+            limit_concurrent_requests: semaphore,
+        }
+    }
+
+    /// Add a new request to the queue and return a stream of InferStreamResponse
+    #[instrument(skip_all)]
+    pub(crate) async fn generate_stream(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<GenerateStreamResponse, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let permit = self
+            .clone()
+            .limit_concurrent_requests
+            .try_acquire_owned()
+            .map_err(|err| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
+                tracing::error!("{err}");
+                err
+            })?;
+
+        // Validate request
+        let valid_request = self.validation.validate(request).await.map_err(|err| {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            err
+        })?;
+
+        // MPSC channel to communicate with the background batching task
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;
+
+        // Append the request to the queue
+        self.queue.append(Entry {
+            request: valid_request,
+            response_tx,
+            span: Span::current(),
+            temp_span: None,
+            queue_time: Instant::now(),
+            batch_time: None,
+        });
+
+        // Notify the background task that we have a new entry in the queue that needs
+        // to be batched
+        self.batching_task_notifier.notify_one();
+
+        // Return stream
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
+    }
+
+    /// Tokenizer the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Tokenization {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
+    /// Apply the chat template to the chat request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_chat_template(
+        &self,
+        messages: Vec<Message>,
+        grammar_with_prompt: Option<(GrammarType, String)>,
+    ) -> Result<String, InferError> {
+        self.chat_template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .apply(messages, grammar_with_prompt)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                e
+            })
+    }
+
+    /// Add a new request to the queue and return a InferResponse
+    #[instrument(skip_all)]
+    pub(crate) async fn generate(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<InferResponse, InferError> {
+        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
+
+        // Create stream and keep semaphore permit as long as generate lives
+        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
+
+        // Return values
+        let mut result_prefill = Vec::new();
+        let mut result_tokens = Vec::new();
+        let mut result_top_tokens = Vec::new();
+        let mut result_generated_text = None;
+        let mut result_start = None;
+        let mut result_queued = None;
+
+        // Iterate on stream
+        while let Some(response) = stream.next().await {
+            match response? {
+                // Add prefill tokens
+                InferStreamResponse::Prefill(prefill_tokens) => {
+                    result_prefill = prefill_tokens;
+                }
+                // Push last token
+                InferStreamResponse::Intermediate { token, top_tokens } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                }
+                // Final message
+                // Set return values
+                InferStreamResponse::End {
+                    token,
+                    generated_text,
+                    start,
+                    queued,
+                    top_tokens,
+                } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                    result_generated_text = Some(generated_text);
+                    result_start = Some(start);
+                    result_queued = Some(queued)
+                }
+            }
+        }
+
+        // Check that we received a `InferStreamResponse::End` message
+        if let (Some(generated_text), Some(queued), Some(start)) =
+            (result_generated_text, result_queued, result_start)
+        {
+            Ok(InferResponse {
+                prefill: result_prefill,
+                _input_length,
+                tokens: result_tokens,
+                generated_text,
+                queued,
+                start,
+                top_tokens: if use_top_tokens {
+                    result_top_tokens
+                } else {
+                    Vec::new()
+                },
+            })
+        } else {
+            let err = InferError::IncompleteGeneration;
+            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
+            tracing::error!("{err}");
+            Err(err)
+        }
+    }
+    /// Add best_of new requests to the queue and return a InferResponse of the sequence with
+    /// the highest log probability per token
+    #[instrument(skip(self, request))]
+    pub(crate) async fn generate_best_of(
+        &self,
+        request: GenerateRequest,
+        best_of: usize,
+    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
+        // validate best_of parameter separately
+        let best_of = self.validation.validate_best_of(best_of)?;
+
+        // create multiple generate requests
+        let mut infer_responses: Vec<InferResponse> =
+            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+
+        // get the sequence with the highest log probability per token
+        let mut max_index = 0;
+        let mut max_logprob: f32 = f32::MIN;
+
+        for (i, response) in infer_responses.iter().enumerate() {
+            // mean logprobs of the generated tokens
+            let sequence_logprob = response
+                .tokens
+                .iter()
+                .map(|token| token.logprob)
+                .sum::<f32>()
+                / response.tokens.len() as f32;
+
+            // set best sequence
+            if sequence_logprob > max_logprob {
+                max_index = i;
+                max_logprob = sequence_logprob;
+            }
+        }
+        let best_response = infer_responses.remove(max_index);
+        Ok((best_response, infer_responses))
+    }
+}
+
+/// Raise a exception (custom function) used in the chat templates
+fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
+    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
+}
+
+#[derive(Clone)]
+struct ChatTemplate {
+    template: Template<'static, 'static>,
+    bos_token: Option<String>,
+    eos_token: Option<String>,
+    use_default_tool_template: bool,
+}
+
+impl ChatTemplate {
+    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
+        let mut env = Box::new(Environment::new());
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+
+        // check if contains the tools variable within the template
+        let use_default_tool_template =
+            !template_str.as_ref().replace(' ', "").contains("{{tools}}");
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self {
+            template,
+            bos_token,
+            eos_token,
+            use_default_tool_template,
+        }
+    }
+
+    fn apply(
+        &self,
+        mut messages: Vec<Message>,
+        grammar_with_prompt: Option<(GrammarType, String)>,
+    ) -> Result<String, InferError> {
+        if self.use_default_tool_template {
+            if let Some(last_message) = messages.last_mut() {
+                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
+                    last_message.content.push(MessageChunk::Text(Text {
+                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
+                    }));
+                }
+            }
+        }
+
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
+        self.template
+            .render(ChatTemplateInputs {
+                messages,
+                bos_token: self.bos_token.as_deref(),
+                eos_token: self.eos_token.as_deref(),
+                add_generation_prompt: true,
+                tools: None,
+                tools_prompt: None,
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
+pub struct ToolGrammar {}
+
+impl ToolGrammar {
+    pub fn apply(
+        tools: Option<Vec<Tool>>,
+        tool_choice: Option<ToolType>,
+    ) -> Result<Option<Tools>, InferError> {
+        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
+            // let tool_prompt = tool_prompt.unwrap_or_default();
+            let tools_to_use = match tool_choice {
+                ToolType::FunctionName(name) => {
+                    vec![req_tools
+                        .iter()
+                        .find(|tool| tool.function.name == *name)
+                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
+                        .clone()]
+                }
+                ToolType::OneOf => req_tools.to_owned(),
+            };
+
+            // adds the error notification function for LLM feedback if required
+            let mut text_response_properties = Map::new();
+            text_response_properties.insert(
+                "error".to_string(),
+                serde_json::json!({
+                    "type": "string",
+                    "description": "The error or issue to notify"
+                }),
+            );
+            text_response_properties.insert(
+                "_name".to_string(),
+                serde_json::json!({
+                    "type": "string",
+                    "const": "notify_error"
+                }),
+            );
+
+            let functions: HashMap<String, serde_json::Value> = tools_to_use
+                .iter()
+                .map(|tool| {
+                    let func = tool.function.clone();
+
+                    // Clone the existing parameters, which are expected to be a JSON object
+                    let mut params = if let Value::Object(params) = &func.arguments {
+                        params.clone()
+                    } else {
+                        Map::new()
+                    };
+
+                    // Insert the function's description at the top level, outside of properties
+                    params.insert(
+                        "description".to_string(),
+                        Value::String(func.description.clone().unwrap_or_default()),
+                    );
+
+                    // Ensure 'properties' exists and is an object
+                    let properties = params
+                        .entry("properties".to_string())
+                        .or_insert_with(|| json!({}))
+                        .as_object_mut()
+                        .unwrap();
+
+                    // Insert the constant for the function name inside 'properties'
+                    properties.insert(
+                        "_name".to_string(),
+                        json!({
+                            "type": "string",
+                            "const": func.name.clone(),
+                            // "description": "The name of the function"
+                        }),
+                    );
+
+                    // Check if 'required' exists, and it is an array. If not, create an empty array.
+                    let required = params
+                        .entry("required".to_string())
+                        .or_insert_with(|| json!([]))
+                        .as_array_mut()
+                        .unwrap();
+
+                    // Add 'name' to the 'required' array if it is not already present
+                    if !required.iter().any(|r| r == "_name") {
+                        required.push(json!("_name"));
+                    }
+
+                    (func.name, Value::Object(params))
+                })
+                .chain([(
+                    "notify_error".to_string(),
+                    serde_json::json!({
+                        "properties": text_response_properties,
+                        "required": ["error", "_name"],
+                        "type": "object"
+                    }),
+                )])
+                .collect();
+
+            let tools = Tools {
+                functions_map: FunctionsMap { functions },
+                properties: Properties {
+                    function: tools_to_use
+                        .iter()
+                        .map(|tool| FunctionRef {
+                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
+                        })
+                        .chain(std::iter::once(FunctionRef {
+                            ref_path: "#/$functions/notify_error".to_string(),
+                        }))
+                        .collect(),
+                },
+            };
+
+            return Ok(Some(tools));
+        }
+        // Err(InferError::ToolError("No tools provided".to_string()))
+        Ok(None)
+    }
+}
+
+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    u32, // input_length
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
+#[derive(Debug)]
+pub(crate) struct GeneratedText {
+    pub(crate) text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) finish_reason: FinishReason,
+    pub(crate) seed: Option<u64>,
+}
+
+#[derive(Debug)]
+pub(crate) enum InferStreamResponse {
+    // Optional first message
+    Prefill(Vec<PrefillToken>),
+    // Intermediate messages
+    Intermediate {
+        token: Token,
+        top_tokens: Vec<Token>,
+    },
+    // Last message
+    End {
+        token: Token,
+        top_tokens: Vec<Token>,
+        generated_text: GeneratedText,
+        start: Instant,
+        queued: Instant,
+    },
+}
+
+#[derive(Debug)]
+pub(crate) struct InferResponse {
+    /// input_length is the input as perceived by the rust tokenizer in the
+    /// validation pathway. It is redundant with prefill.len() but prefill
+    /// has data only if the user asked for it. This will always be filled.
+    pub(crate) _input_length: u32,
+    pub(crate) prefill: Vec<PrefillToken>,
+    pub(crate) tokens: Vec<Token>,
+    pub(crate) generated_text: GeneratedText,
+    pub(crate) queued: Instant,
+    pub(crate) start: Instant,
+    pub(crate) top_tokens: Vec<Vec<Token>>,
+}
+
+#[derive(Debug, Error)]
+pub enum InferError {
+    #[error("Request failed during generation: {0}")]
+    GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded(#[from] TryAcquireError),
+    #[error("Input validation error: {0}")]
+    ValidationError(#[from] ValidationError),
+    #[error("Incomplete generation")]
+    IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
+    #[error("Tool error: {0}")]
+    ToolError(String),
+}
+
+impl InferError {
+    pub(crate) fn error_type(&self) -> &str {
+        match self {
+            InferError::GenerationError(_) => "generation",
+            InferError::Overloaded(_) => "overloaded",
+            InferError::ValidationError(_) => "validation",
+            InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::TemplateError(_) => "template_error",
+            InferError::ToolError(_) => "tool_error",
+        }
+    }
+}
@@ -1,511 +1,27 @@
 /// Batching and inference logic

-use crate::infer::v2::{Queue, Entry};
-use crate::validation::{Validation, ValidationError};
-use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest,
-    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
-    PrefillToken, Text, TextMessage, Token,
-};
-use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
-use futures::future::try_join_all;
-use minijinja::{Environment, ErrorKind, Template};
+use crate::infer::Entry;
+use crate::infer::v2::{Queue};
+use crate::{FinishReason, PrefillToken, Token};
 use nohash_hasher::IntMap;
-use serde_json::{json, Map, Value};
-use std::collections::HashMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
 use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
-use text_generation_client::{v2, ClientError};
-use thiserror::Error;
+use text_generation_client::{ClientError};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::Notify;
 use tokio::time::Instant;
-use tokio_stream::wrappers::UnboundedReceiverStream;
-use tokio_stream::StreamExt;
-use tracing::{info_span, instrument, Instrument, Span};
+use tracing::{info_span, instrument, Instrument};
+use crate::infer::{GeneratedText, InferError, InferStreamResponse};

[... several hundred deleted lines elided: the old `Infer` struct (holding the concrete v2
 `Queue` and a `Shared { batching_task: Notify }`), the `Shared` struct, `raise_exception`,
 the full `impl Infer` (new, generate_stream, tokenize, apply_chat_template, generate,
 generate_best_of), `ChatTemplate`, and `ToolGrammar`; apart from how the queue and the
 batching-task notifier are stored and wired, this is the same code the first hunk adds to
 the module root ...]

 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
 #[allow(clippy::too_many_arguments)]
-async fn batching_task(
+pub(crate) async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
@@ -513,13 +29,13 @@ async fn batching_task(
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
     queue: Queue,
-    shared: Arc<Shared>,
+    notifier: Arc<Notify>,
     generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
-        shared.batching_task.notified().await;
+        notifier.notified().await;

         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
@@ -880,28 +396,13 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     });
 }

-/// Type alias for generation responses
-pub(crate) type GenerateStreamResponse = (
-    OwnedSemaphorePermit,
-    u32, // input_length
-    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-);
-
-#[derive(Debug)]
-pub(crate) struct GeneratedText {
-    pub(crate) text: String,
-    pub(crate) generated_tokens: u32,
-    pub(crate) finish_reason: FinishReason,
-    pub(crate) seed: Option<u64>,
-}
-
-impl From<v2::GeneratedText> for GeneratedText {
-    fn from(value: v2::GeneratedText) -> Self {
-        let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap();
+impl From<text_generation_client::v2::GeneratedText> for GeneratedText {
+    fn from(value: text_generation_client::v2::GeneratedText) -> Self {
+        let v2_finish_reason = text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap();
         let finish_reason = match v2_finish_reason {
-            v2::FinishReason::Length => FinishReason::Length,
-            v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
-            v2::FinishReason::StopSequence => FinishReason::StopSequence,
+            text_generation_client::v2::FinishReason::Length => FinishReason::Length,
+            text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            text_generation_client::v2::FinishReason::StopSequence => FinishReason::StopSequence,
         };

         Self {
@@ -913,68 +414,6 @@ impl From<v2::GeneratedText> for GeneratedText {
     }
 }

[... roughly 60 deleted lines elided: the `InferStreamResponse` enum, `InferResponse` struct,
 `InferError` enum, and `impl InferError` removed from this file; identical to the definitions
 added to the module root in the first hunk ...]

 // tests
 #[cfg(test)]
 mod tests {
@@ -1,5 +1,5 @@
-mod infer;
+mod batcher;
 mod queue;

-pub(crate) use infer::{Infer, InferError, InferStreamResponse, InferResponse, ToolGrammar};
-pub(crate) use queue::{Entry, Queue};
+pub(crate) use batcher::batching_task;
+pub(crate) use queue::Queue;
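Taken together with the hunks above, the module layout after this commit is roughly the following (inferred from the `mod` and `use` lines in this diff; file paths themselves are not shown in this mirror view):

    infer          - shared front-end: Entry, InferQueue, Infer, ChatTemplate, ToolGrammar,
                     GeneratedText, InferStreamResponse, InferResponse, InferError
    infer::health  - HealthCheck
    infer::v2      - concrete backend: batcher (batching_task) and queue (Queue)
    infer::v3      - still disabled (`// pub(crate) mod v3;`)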
@@ -1,6 +1,6 @@
-use crate::infer::v2::{InferError, InferStreamResponse};
+use crate::infer::{Entry, InferQueue};
 use crate::validation::{
-    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
+    ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
@@ -15,23 +15,6 @@ use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

-/// Queue entry
-#[derive(Debug)]
-pub(crate) struct Entry {
-    /// Request
-    pub request: ValidGenerateRequest,
-    /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
-    /// Span that will live as long as entry
-    pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
-    pub temp_span: Option<Span>,
-    /// Instant when this entry was queued
-    pub queue_time: Instant,
-    /// Instant when this entry was added to a batch
-    pub batch_time: Option<Instant>,
-}
-
 /// Request Queue
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
@@ -39,6 +22,19 @@ pub(crate) struct Queue {
     queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }

+impl InferQueue for Queue {
+    /// Append an entry to the queue
+    #[instrument(skip_all)]
+    fn append(&self, entry: Entry) {
+        // Send append command to the background task managing the state
+        // Unwrap is safe here
+        self.queue_sender
+            .send(QueueCommand::Append(Box::new(entry), Span::current()))
+            .unwrap();
+    }
+}
+
 impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
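With this impl block, the concrete v2 `Queue` satisfies the `InferQueue` trait that `Infer` now stores. The wiring back in `Infer::new` (first hunk) is then just an unsized coercion when the field is filled; a condensed view, repeating lines already shown above:

    let queue = v2::Queue::new(requires_padding, 16, window_size, speculate);
    // ...
    Self {
        // Arc::new(queue) coerces from Arc<v2::Queue> to
        // Arc<dyn InferQueue + Send + Sync>, the declared field type.
        queue: Arc::new(queue),
        // ...
    }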
@@ -61,16 +57,6 @@ impl Queue {
         Self { queue_sender }
     }

-    /// Append an entry to the queue
-    #[instrument(skip_all)]
-    pub(crate) fn append(&self, entry: Entry) {
-        // Send append command to the background task managing the state
-        // Unwrap is safe here
-        self.queue_sender
-            .send(QueueCommand::Append(Box::new(entry), Span::current()))
-            .unwrap();
-    }
-
     // Get the next batch
     #[instrument(skip(self))]
     pub(crate) async fn next_batch(
@@ -1,7 +1,8 @@
 /// HTTP Server logic
+
 use crate::config::Config;
 use crate::infer::HealthCheck;
-use crate::infer::v2::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -34,7 +35,7 @@ use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use text_generation_client::{v2::ShardedClient, ShardInfo, ClientError};
+use text_generation_client::{v2::ShardedClient, ClientError};
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -1471,36 +1472,17 @@ pub async fn run(
     )]
     struct ApiDoc;

-    // Instantiate sharded client from the master unix socket
-    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
-        .await
-        .map_err(WebServerError::Connection)?;
-    // Clear the cache; useful if the webserver rebooted
-    sharded_client
-        .clear_cache(None)
-        .await
-        .map_err(WebServerError::Cache)?;
-    // Get info from the shard
-    let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
-
-    // Warmup model
-    tracing::info!("Warming up model");
-    let max_batch_total_tokens = match sharded_client
-        .warmup(
-            max_input_tokens as u32,
-            max_batch_prefill_tokens,
-            max_total_tokens as u32,
-            max_batch_size,
-        )
-        .await
-        .map_err(WebServerError::Warmup)?
-    {
+    // Open connection, get model info and warmup
+    let (infer, health_ext, shard_info, max_batch_total_tokens) = {
+        // Helper function to check both v2 and v3
+        let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
+            match max_supported_batch_total_tokens {
         // Older models do not support automatic max-batch-total-tokens
         None => {
             let max_batch_total_tokens = max_batch_total_tokens
                 .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
             tracing::warn!("Model does not support automatic max batch total tokens");
-            max_batch_total_tokens
+            Ok(max_batch_total_tokens)
         }
         // Flash attention models return their max supported total tokens
         Some(max_supported_batch_total_tokens) => {
@@ -1515,13 +1497,13 @@ pub async fn run(
                 );
             }
             if max_total_tokens as u32 > max_supported_batch_total_tokens {
-                return Err(WebServerError::NotEnoughMemory(max_total_tokens))
+                return Err(WebServerError::NotEnoughMemory(max_total_tokens));
             }

-            max_supported_batch_total_tokens
+            Ok(max_supported_batch_total_tokens)
+            }
         }
     };
-    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");

     // Create state
     let validation = Validation::new(
@@ -1536,6 +1518,33 @@ pub async fn run(
         grammar_support,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
+
+    // Try to open a v3 client
+    // Instantiate sharded client from the master unix socket
+    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+        .await
+        .map_err(WebServerError::Connection)?;
+    // Clear the cache; useful if the webserver rebooted
+    sharded_client
+        .clear_cache(None)
+        .await
+        .map_err(WebServerError::Cache)?;
+    // Get info from the shard
+    let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
+
+    // Warmup model
+    tracing::info!("Warming up model");
+    let max_batch_total_tokens = check_max_batch_total_tokens(sharded_client
+        .warmup(
+            max_input_tokens as u32,
+            max_batch_prefill_tokens,
+            max_total_tokens as u32,
+            max_batch_size,
+        )
+        .await
+        .map_err(WebServerError::Warmup)?)?;
+    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
+
     let health_ext = HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
     let infer = Infer::new(
         sharded_client,
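The warmup result now flows through `check_max_batch_total_tokens`, so the limit for models with and without automatic support is computed in one place. For models that do not report a supported value, the fallback is `max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)))`; a worked example with illustrative values, not taken from this commit:

    // max_total_tokens = 4096, max_batch_prefill_tokens = 4096, no explicit
    // max_batch_total_tokens:
    //     4096.max(4096) = 4096, then 16000.max(4096) = 16000   => limit is 16000
    //
    // same, but max_batch_prefill_tokens = 32768:
    //     4096.max(32768) = 32768, then 16000.max(32768) = 32768 => limit is 32768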
@@ -1554,6 +1563,9 @@ pub async fn run(
         processor_config,
     );

+        (infer, health_ext, shard_info, max_batch_total_tokens)
+    };
+
     // Duration buckets
     let duration_matcher = Matcher::Suffix(String::from("duration"));
     let n_duration_buckets = 35;
@@ -1821,5 +1833,5 @@ pub enum WebServerError {
     #[error("Not enough memory to handle `max_total_tokens={0}`")]
     NotEnoughMemory(usize),
     #[error("Axum error: {0}")]
-    Axum(#[from] axum::BoxError)
+    Axum(#[from] axum::BoxError),
 }