Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)

commit 504754861f — "wip"
parent be2d38032a
@@ -1,511 +1,84 @@
-/// Batching and inference logic
-use crate::infer::v3::queue::{Entry, Queue};
-use crate::infer::{
-    GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
-};
-use crate::validation::ValidGenerateRequest;
-use crate::{FinishReason, PrefillToken, Token};
-use nohash_hasher::IntMap;
-use std::sync::{
-    atomic::{AtomicBool, Ordering},
-    Arc,
-};
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
-use text_generation_client::ClientError;
-use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
-use tokio::time::Instant;
-use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{info_span, instrument, Instrument, Span};
-
-pub(crate) struct SchedulerV3 {
-    /// Request queue
-    queue: Queue,
-    /// Notify batcher on queue appends
-    batching_task_notifier: Arc<Notify>,
-}
-
-impl SchedulerV3 {
-    #[allow(clippy::too_many_arguments)]
-    pub(crate) fn new(
-        client: ShardedClient,
-        waiting_served_ratio: f32,
-        max_batch_prefill_tokens: u32,
-        max_batch_total_tokens: u32,
-        max_waiting_tokens: usize,
-        max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
-        generation_health: Arc<AtomicBool>,
-    ) -> Self {
-        let queue = Queue::new(
-            requires_padding,
-            16,
-            window_size,
-            speculate,
-            max_batch_total_tokens,
-        );
-        let batching_task_notifier = Arc::new(Notify::new());
-
-        // Spawn batching background task that contains all the inference logic
-        tokio::spawn(batching_task(
-            client,
-            waiting_served_ratio,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-            max_waiting_tokens,
-            max_batch_size,
-            queue.clone(),
-            batching_task_notifier.clone(),
-            generation_health,
-        ));
-
-        Self {
-            queue,
-            batching_task_notifier,
-        }
-    }
-}
-
-impl Scheduler for SchedulerV3 {
-    #[instrument(skip_all)]
-    fn schedule(
-        &self,
-        request: ValidGenerateRequest,
-        permit: OwnedSemaphorePermit,
-    ) -> Result<GenerateStreamResponse, InferError> {
-        // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = mpsc::unbounded_channel();
-        let input_length = request.input_length;
-
-        // Append the request to the queue
-        self.queue.append(Entry {
-            request,
-            response_tx,
-            span: Span::current(),
-            temp_span: None,
-            queue_time: Instant::now(),
-            batch_time: None,
-            block_allocation: None,
-        });
-
-        // Notify the background task that we have a new entry in the queue that needs
-        // to be batched
-        self.batching_task_notifier.notify_one();
-
-        // Return stream
-        Ok((
-            permit,
-            input_length,
-            UnboundedReceiverStream::new(response_rx),
-        ))
-    }
-}
-
-/// Batching logic
-/// Will be launched in a background Tokio task
-///
-/// Batches requests and sends them to the inference server
-#[allow(clippy::too_many_arguments)]
-pub(crate) async fn batching_task(
-    mut client: ShardedClient,
-    waiting_served_ratio: f32,
-    max_batch_prefill_tokens: u32,
-    max_batch_total_tokens: u32,
-    max_waiting_tokens: usize,
-    max_batch_size: Option<usize>,
-    queue: Queue,
-    notifier: Arc<Notify>,
-    generation_health: Arc<AtomicBool>,
-) {
-    // Infinite loop
-    loop {
-        // Wait for a notification from the Infer struct
-        notifier.notified().await;
-
-        // Get the next batch from the queue
-        // This batch might be smaller than the maximum batch size if there are not enough requests
-        // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue
-            .next_batch(
-                None,
-                max_batch_size,
-                max_batch_prefill_tokens,
-                max_batch_total_tokens,
-            )
-            .await
-        {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
-                .instrument(span)
-                .await;
-            let mut waiting_tokens = 1;
-
-            // We loop until we do not receive any cached batch from the inference server (== until
-            // all requests have met their stopping criteria)
-            while let Some(batch) = cached_batch {
-                // Get current batch info
-                let batch_size = batch.size;
-                let batch_max_tokens = batch.max_tokens;
-                let mut batches = vec![batch];
-                metrics::gauge!("tgi_batch_current_size", batch_size as f64);
-                metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
-
-                let min_size = if waiting_tokens >= max_waiting_tokens {
-                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                    // to add a new batch even though its size might be small
-                    None
-                } else {
-                    // Minimum batch size
-                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
-                };
-
-                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
-
-                // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
-                    .await
-                {
-                    // Tracking metrics
-                    if min_size.is_some() {
-                        metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
-                    } else {
-                        metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
-                    }
-
-                    entries.iter_mut().for_each(|(_, entry)| {
-                        // Create a new span to add the info that this entry is waiting
-                        // because a new batch is being computed
-                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                        // Add relationships
-                        span.follows_from(&entry_waiting_span);
-                        entry_waiting_span.follows_from(&span);
-                        // Update entry
-                        entry.temp_span = Some(entry_waiting_span);
-                    });
-
-                    // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch =
-                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
-                            .instrument(span)
-                            .await;
-                    // Reset waiting counter
-                    waiting_tokens = 1;
-                    // Extend current batch with the new batch
-                    if let Some(new_cached_batch) = new_cached_batch {
-                        entries.extend(new_entries);
-                        batches.push(new_cached_batch);
-                    }
-                }
-
-                // Create span for this batch to add context to inference calls
-                let next_batch_size = entries.len();
-                let next_batch_span =
-                    info_span!(parent: None, "batch", batch_size = next_batch_size);
-                entries.iter_mut().for_each(|(_, entry)| {
-                    // Create a new span to link the batch back to this entry
-                    let entry_batch_span = info_span!(parent: &entry.span, "infer");
-                    // Add relationships
-                    next_batch_span.follows_from(&entry_batch_span);
-                    entry_batch_span.follows_from(&next_batch_span);
-                    // Update entry
-                    entry.temp_span = Some(entry_batch_span);
-                });
-
-                cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
-                    .instrument(next_batch_span)
-                    .await;
-                waiting_tokens += 1;
-            }
-            metrics::gauge!("tgi_batch_current_size", 0.0);
-            metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
-        }
-    }
-}
-
-#[instrument(skip_all)]
-async fn prefill(
-    client: &mut ShardedClient,
-    batch: Batch,
-    entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
-) -> Option<CachedBatch> {
-    let start_time = Instant::now();
-    let batch_id = batch.id;
-    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
-
-    match client.prefill(batch).await {
-        Ok((generations, next_batch, timings)) => {
-            // Update health
-            generation_health.store(true, Ordering::SeqCst);
-
-            let start_filtering_time = Instant::now();
-            // Send generated tokens and filter stopped entries
-            filter_send_generations(generations, entries);
-
-            // Filter next batch and remove requests that were stopped
-            let next_batch = filter_batch(client, next_batch, entries).await;
-
-            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
-            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
-            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
-            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
-            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
-            next_batch
-        }
-        // If we have an error, we discard the whole batch
-        Err(err) => {
-            // Update health
-            generation_health.store(false, Ordering::SeqCst);
-            let _ = client.clear_cache(Some(batch_id)).await;
-            send_errors(err, entries);
-            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
-            None
-        }
-    }
-}
-
-#[instrument(skip_all)]
-async fn decode(
-    client: &mut ShardedClient,
-    batches: Vec<CachedBatch>,
-    entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
-) -> Option<CachedBatch> {
-    let start_time = Instant::now();
-    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-    metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
-
-    match client.decode(batches).await {
-        Ok((generations, next_batch, timings)) => {
-            // Update health
-            generation_health.store(true, Ordering::SeqCst);
-
-            let start_filtering_time = Instant::now();
-            // Send generated tokens and filter stopped entries
-            filter_send_generations(generations, entries);
-
-            // Filter next batch and remove requests that were stopped
-            let next_batch = filter_batch(client, next_batch, entries).await;
-
-            if let Some(concat_duration) = timings.concat {
-                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
-            }
-            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
-            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
-            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
-            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
-            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            next_batch
-        }
-        // If we have an error, we discard the whole batch
-        Err(err) => {
-            generation_health.store(false, Ordering::SeqCst);
-            for id in batch_ids {
-                let _ = client.clear_cache(Some(id)).await;
-            }
-            send_errors(err, entries);
-            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
-            None
-        }
-    }
-}
-
-/// Filter a `batch` and remove all requests not present in `entries`
-#[instrument(skip_all)]
-async fn filter_batch(
-    client: &mut ShardedClient,
-    next_batch: Option<CachedBatch>,
-    entries: &IntMap<u64, Entry>,
-) -> Option<CachedBatch> {
-    let mut batch = next_batch?;
-
-    // No need to filter
-    if batch.size as usize == entries.len() {
-        return Some(batch);
-    }
-
-    let id = batch.id;
-
-    // Retain only requests that are still in entries
-    batch.request_ids.retain(|id| entries.contains_key(id));
-
-    if batch.request_ids.is_empty() {
-        // All requests have been filtered out
-        // Next batch is now empty
-        // Clear it from the Python shards cache
-        // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.clear_cache(Some(id)).await.unwrap();
-        None
-    } else {
-        // Filter Python shard cache
-        // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.request_ids).await.unwrap()
-    }
-}
-
-/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
-/// and filter entries
-#[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
-    generations.into_iter().for_each(|generation| {
-        let id = generation.request_id;
-        // Get entry
-        // We can `expect` here as the request id should always be in the entries
-        let entry = entries
-            .get(&id)
-            .expect("ID not found in entries. This is a bug.");
-
-        // Create and enter a span to link this function back to the entry
-        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
-        // Send generation responses back to the infer task
-        // If the receive an error from the Flume channel, it means that the client dropped the
-        // request and we need to stop generating hence why we unwrap_or(true)
-        let stopped = send_responses(generation, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }).unwrap_or(true);
-        if stopped {
-            entries.remove(&id).expect("ID not found in entries. This is a bug.");
-        }
-    });
-}
-
-/// Send responses through the `entry` response channel
-fn send_responses(
-    generation: Generation,
-    entry: &Entry,
-) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
-    // Return directly if the channel is disconnected
-    if entry.response_tx.is_closed() {
-        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-        return Ok(true);
-    }
-
-    let mut stopped = false;
-
-    if let Some(prefill_tokens) = generation.prefill_tokens {
-        // Create Token objects
-        // We do that here instead of in the Python code as Rust for loops are faster
-        let prefill_tokens = prefill_tokens
-            .ids
-            .into_iter()
-            .zip(prefill_tokens.logprobs)
-            .zip(prefill_tokens.texts)
-            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-            .collect();
-
-        // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
-    }
-
-    // Create last Token
-    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
-    let n = tokens_.ids.len();
-    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-    let mut iterator = tokens_
-        .ids
-        .into_iter()
-        .zip(tokens_.logprobs)
-        .zip(tokens_.texts)
-        .zip(tokens_.is_special)
-        .enumerate()
-        .peekable();
-    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
-        let token = Token {
-            id,
-            text,
-            logprob,
-            special,
-        };
-        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
-            top_tokens_
-                .ids
-                .iter()
-                .zip(top_tokens_.logprobs.iter())
-                .zip(top_tokens_.texts.iter())
-                .zip(top_tokens_.is_special.iter())
-                .map(|(((&id, &logprob), text), &special)| Token {
-                    id,
-                    text: text.to_string(),
-                    logprob,
-                    special,
-                })
-                .collect()
-        } else {
-            vec![]
-        };
-        match (&generation.generated_text, iterator.peek()) {
-            (Some(generated_text), None) => {
-                // Generation has ended
-                stopped = true;
-                // Send message
-                entry.response_tx.send(Ok(InferStreamResponse::End {
-                    token,
-                    top_tokens,
-                    generated_text: GeneratedText::from(generated_text.clone()),
-                    queued: entry.queue_time,
-                    start: entry.batch_time.unwrap(),
-                }))?;
-            }
-            _ => {
-                // Send message
-                entry
-                    .response_tx
-                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
-            }
-        }
-    }
-
-    Ok(stopped)
-}
-
-/// Send errors to Infer for all `entries`
-#[instrument(skip_all)]
-fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
-    entries.drain().for_each(|(_, entry)| {
-        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-        let err = InferError::GenerationError(error.to_string());
-        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
-        tracing::error!("{err}");
-
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Err(err))
-            .unwrap_or(());
-    });
-}
-
-impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
-    fn from(value: text_generation_client::v3::GeneratedText) -> Self {
-        let v3_finish_reason =
-            text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
-        let finish_reason = match v3_finish_reason {
-            text_generation_client::v3::FinishReason::Length => FinishReason::Length,
-            text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
-            text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
-        };
-
-        Self {
-            text: value.text,
-            generated_tokens: value.generated_tokens,
-            finish_reason,
-            seed: value.seed,
-        }
-    }
-}
-
-// tests
-#[cfg(test)]
-mod tests {
-    use crate::infer::raise_exception;
-    use crate::{ChatTemplateInputs, TextMessage};
-    use minijinja::Environment;
+use crate::infer::InferError;
+use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
+use minijinja::{Environment, ErrorKind, Template};
+use minijinja_contrib::pycompat;
+
+/// Raise a exception (custom function) used in the chat templates
+pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
+    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
+}
+
+#[derive(Clone)]
+pub(crate) struct ChatTemplate {
+    template: Template<'static, 'static>,
+    bos_token: Option<String>,
+    eos_token: Option<String>,
+    use_default_tool_template: bool,
+}
+
+impl ChatTemplate {
+    pub(crate) fn new(
+        template: String,
+        bos_token: Option<String>,
+        eos_token: Option<String>,
+    ) -> Self {
+        let mut env = Box::new(Environment::new());
+        // enable things like .strip() or .capitalize()
+        env.set_unknown_method_callback(pycompat::unknown_method_callback);
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+
+        // check if contains the tools variable within the template
+        let use_default_tool_template =
+            !template_str.as_ref().replace(' ', "").contains("{{tools}}");
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self {
+            template,
+            bos_token,
+            eos_token,
+            use_default_tool_template,
+        }
+    }
+
+    pub(crate) fn apply(
+        &self,
+        mut messages: Vec<Message>,
+        grammar_with_prompt: Option<(GrammarType, String)>,
+    ) -> Result<String, InferError> {
+        if self.use_default_tool_template {
+            if let Some(last_message) = messages.last_mut() {
+                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
+                    last_message.content.push(MessageChunk::Text(Text {
+                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
+                    }));
+                }
+            }
+        }
+
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
+        self.template
+            .render(ChatTemplateInputs {
+                messages,
+                bos_token: self.bos_token.as_deref(),
+                eos_token: self.eos_token.as_deref(),
+                add_generation_prompt: true,
+                tools: None,
+                tools_prompt: None,
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
+// tests
+#[cfg(test)]
+mod tests {
+    use crate::infer::chat_template::raise_exception;
+    use crate::{ChatTemplateInputs, TextMessage};
+    use minijinja::Environment;
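Note (added for illustration, not part of the commit): the new ChatTemplate above leans on minijinja plus minijinja_contrib's pycompat shim. Below is a minimal, standalone sketch of that pattern with a made-up inline template; it does not use the crate's ChatTemplateInputs or Message types.

// Standalone sketch of the minijinja pattern ChatTemplate::new relies on:
// python-style string methods via pycompat, a `raise_exception` helper, and
// rendering a chat template against a list of messages.
use minijinja::{context, Environment, ErrorKind};
use minijinja_contrib::pycompat;

fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}

fn main() {
    let mut env = Environment::new();
    // enable things like .strip() or .capitalize() inside the template
    env.set_unknown_method_callback(pycompat::unknown_method_callback);
    env.add_function("raise_exception", raise_exception);
    env.add_template(
        "chat",
        "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
    )
    .unwrap();

    // Render the template with an illustrative message list.
    let prompt = env
        .get_template("chat")
        .unwrap()
        .render(context! {
            messages => vec![
                context! { role => "user", content => "Hello!" },
            ],
        })
        .unwrap();
    println!("{prompt}");
}

The real code additionally leaks the environment and template string to obtain 'static lifetimes, a one-time cost noted in the comment "leaking env and template_str as read-only, static resources for performance."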
@@ -1,34 +0,0 @@
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-use text_generation_client::Health;
-
-#[derive(Clone)]
-pub(crate) struct HealthCheck {
-    client: Arc<dyn Health + Send + Sync>,
-    generation_health: Arc<AtomicBool>,
-}
-
-impl HealthCheck {
-    pub(crate) fn new(
-        client: Arc<dyn Health + Send + Sync>,
-        generation_health: Arc<AtomicBool>,
-    ) -> Self {
-        Self {
-            client,
-            generation_health,
-        }
-    }
-
-    pub(crate) async fn check(&mut self) -> bool {
-        let value = if self.generation_health.load(Ordering::SeqCst) {
-            // Generation is healthy, we only check that the shards can allocate on device
-            self.client.device_health().await
-        } else {
-            self.client.model_health().await
-        }
-        .is_ok();
-        // Update generation health
-        self.generation_health.store(value, Ordering::SeqCst);
-        value
-    }
-}
@@ -1,6 +1,8 @@
 mod health;
 pub(crate) mod v2;
 pub(crate) mod v3;
+mod chat_template;
+mod tool_grammar;

 pub(crate) use health::HealthCheck;

@@ -23,6 +25,7 @@ use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;
+use chat_template::ChatTemplate;

 pub(crate) trait Scheduler {
     fn schedule(
@@ -37,18 +40,20 @@ pub(crate) trait Scheduler {
 pub struct Infer {
     /// Validation
     validation: Validation,
-    /// Request scheduler
-    scheduler: Arc<dyn Scheduler + Send + Sync>,
+    /// Request backend
+    backend: Arc<dyn Backend + Send + Sync>,
     /// Chat template
     chat_template: Option<ChatTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
+    /// Backend health
+    backend_health: Arc<AtomicBool>,
 }

 impl Infer {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
-        scheduler: Arc<dyn Scheduler + Send + Sync>,
+        backend: impl Backend + Send + Sync + 'static,
         validation: Validation,
         max_concurrent_requests: usize,
         tokenizer_config: HubTokenizerConfig,
@@ -69,20 +74,31 @@ impl Infer {
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

+        // Backend health
+        let backend_health = Arc::new(AtomicBool::new(false));
+
         Self {
             validation,
-            scheduler,
+            backend: Arc::new(backend),
             chat_template,
             limit_concurrent_requests: semaphore,
+            backend_health,
         }
     }

     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip_all)]
-    pub(crate) async fn generate_stream(
-        &self,
+    pub(crate) async fn generate_stream<'a>(
+        &'a self,
         request: GenerateRequest,
-    ) -> Result<GenerateStreamResponse, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            u32, // input_length
+            impl Stream<Item=Result<InferStreamResponse, InferError>> + 'a,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -101,7 +117,36 @@ impl Infer {
                 err
             })?;

-        self.scheduler.schedule(valid_request, permit)
+        let input_length = valid_request.input_length;
+        let mut generation_stream = self
+            .backend
+            .schedule(valid_request)
+            .map_err(InferError::Backend)?;
+
+        let stream = stream! {
+            while let Some(generation) = generation_stream.next().await {
+                self.backend_health.store(generation.is_ok(), Ordering::SeqCst);
+
+                yield generation.map(|generation| match generation {
+                    types::TokenStreamResponse::Prefill(prefill_tokens) => InferStreamResponse::Prefill(
+                        prefill_tokens.into_iter().map(PrefillToken::from).collect()
+                    ),
+                    types::TokenStreamResponse::Intermediate { token, top_tokens } => InferStreamResponse::Intermediate {
+                        token: Token::from(token),
+                        top_tokens: top_tokens.into_iter().map(Token::from).collect(),
+                    },
+                    types::TokenStreamResponse::End { token, top_tokens, generated_text, start, queued } => InferStreamResponse::End {
+                        token: Token::from(token),
+                        top_tokens: top_tokens.into_iter().map(Token::from).collect(),
+                        generated_text,
+                        start,
+                        queued,
+                    }
+                }).map_err(InferError::GenerationError)
+            }
+        };
+
+        Ok((permit, input_length, stream))
     }

     /// Tokenizer the input
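Note (added for illustration, not part of the commit): generate_stream now builds its response with the async-stream `stream!` macro instead of returning an UnboundedReceiverStream, which is why the call sites below pin the stream before polling it. A tiny standalone sketch of that pattern, assuming the async-stream, futures and tokio crates:

// Build an `impl Stream` with `stream!`, then pin it before iterating.
use async_stream::stream;
use futures::{Stream, StreamExt};

fn numbers() -> impl Stream<Item = u32> {
    stream! {
        for i in 0..3 {
            // the body may await between yields, e.g. a gRPC call or channel recv
            yield i;
        }
    }
}

#[tokio::main]
async fn main() {
    // `stream!` returns an unnamed type that is not Unpin, so pin it first,
    // mirroring the `Box::pin(stream)` / `Box::pin(response_stream)` lines below.
    let mut s = Box::pin(numbers());
    while let Some(n) = s.next().await {
        println!("{n}");
    }
}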
@@ -153,7 +198,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
+        let (_permit, _input_length, stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -163,6 +208,8 @@ impl Infer {
         let mut result_start = None;
         let mut result_queued = None;

+        let mut stream = Box::pin(stream);
+
         // Iterate on stream
         while let Some(response) = stream.next().await {
             match response? {
@@ -254,190 +301,15 @@
         let best_response = infer_responses.remove(max_index);
         Ok((best_response, infer_responses))
     }
+
+    #[instrument(skip(self))]
+    pub(crate) async fn health(&self) -> bool {
+        let health = self
+            .backend
+            .health(self.backend_health.load(Ordering::SeqCst))
+            .await;
+        self.backend_health.store(health, Ordering::SeqCst);
+        health
+    }
 }
-
-/// Raise a exception (custom function) used in the chat templates
-fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
-    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
-}
-
-#[derive(Clone)]
-struct ChatTemplate {
-    template: Template<'static, 'static>,
-    bos_token: Option<String>,
-    eos_token: Option<String>,
-    use_default_tool_template: bool,
-}
-
-impl ChatTemplate {
-    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
-        let mut env = Box::new(Environment::new());
-        // enable things like .strip() or .capitalize()
-        env.set_unknown_method_callback(pycompat::unknown_method_callback);
-        let template_str = template.into_boxed_str();
-        env.add_function("raise_exception", raise_exception);
-
-        // check if contains the tools variable within the template
-        let use_default_tool_template =
-            !template_str.as_ref().replace(' ', "").contains("{{tools}}");
-        // leaking env and template_str as read-only, static resources for performance.
-        let template = Box::leak(env)
-            .template_from_str(Box::leak(template_str))
-            .unwrap();
-
-        Self {
-            template,
-            bos_token,
-            eos_token,
-            use_default_tool_template,
-        }
-    }
-
-    fn apply(
-        &self,
-        mut messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
-    ) -> Result<String, InferError> {
-        if self.use_default_tool_template {
-            if let Some(last_message) = messages.last_mut() {
-                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text(Text {
-                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    }));
-                }
-            }
-        }
-
-        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
-        self.template
-            .render(ChatTemplateInputs {
-                messages,
-                bos_token: self.bos_token.as_deref(),
-                eos_token: self.eos_token.as_deref(),
-                add_generation_prompt: true,
-                tools: None,
-                tools_prompt: None,
-            })
-            .map_err(InferError::TemplateError)
-    }
-}
-
-pub struct ToolGrammar {}
-
-impl ToolGrammar {
-    pub fn apply(
-        tools: Option<Vec<Tool>>,
-        tool_choice: Option<ToolType>,
-    ) -> Result<Option<Tools>, InferError> {
-        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
-            // let tool_prompt = tool_prompt.unwrap_or_default();
-            let tools_to_use = match tool_choice {
-                ToolType::FunctionName(name) => {
-                    vec![req_tools
-                        .iter()
-                        .find(|tool| tool.function.name == *name)
-                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
-                        .clone()]
-                }
-                ToolType::OneOf => req_tools.to_owned(),
-            };
-
-            // adds the error notification function for LLM feedback if required
-            let mut text_response_properties = Map::new();
-            text_response_properties.insert(
-                "error".to_string(),
-                serde_json::json!({
-                    "type": "string",
-                    "description": "The error or issue to notify"
-                }),
-            );
-            text_response_properties.insert(
-                "_name".to_string(),
-                serde_json::json!({
-                    "type": "string",
-                    "const": "notify_error"
-                }),
-            );
-
-            let functions: HashMap<String, serde_json::Value> = tools_to_use
-                .iter()
-                .map(|tool| {
-                    let func = tool.function.clone();
-
-                    // Clone the existing parameters, which are expected to be a JSON object
-                    let mut params = if let Value::Object(params) = &func.arguments {
-                        params.clone()
-                    } else {
-                        Map::new()
-                    };
-
-                    // Insert the function's description at the top level, outside of properties
-                    params.insert(
-                        "description".to_string(),
-                        Value::String(func.description.clone().unwrap_or_default()),
-                    );
-
-                    // Ensure 'properties' exists and is an object
-                    let properties = params
-                        .entry("properties".to_string())
-                        .or_insert_with(|| json!({}))
-                        .as_object_mut()
-                        .unwrap();
-
-                    // Insert the constant for the function name inside 'properties'
-                    properties.insert(
-                        "_name".to_string(),
-                        json!({
-                            "type": "string",
-                            "const": func.name.clone(),
-                            // "description": "The name of the function"
-                        }),
-                    );
-
-                    // Check if 'required' exists, and it is an array. If not, create an empty array.
-                    let required = params
-                        .entry("required".to_string())
-                        .or_insert_with(|| json!([]))
-                        .as_array_mut()
-                        .unwrap();
-
-                    // Add 'name' to the 'required' array if it is not already present
-                    if !required.iter().any(|r| r == "_name") {
-                        required.push(json!("_name"));
-                    }
-
-                    (func.name, Value::Object(params))
-                })
-                .chain([(
-                    "notify_error".to_string(),
-                    serde_json::json!({
-                        "properties": text_response_properties,
-                        "required": ["error", "_name"],
-                        "type": "object"
-                    }),
-                )])
-                .collect();
-
-            let tools = Tools {
-                functions_map: FunctionsMap { functions },
-                properties: Properties {
-                    function: tools_to_use
-                        .iter()
-                        .map(|tool| FunctionRef {
-                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
-                        })
-                        .chain(std::iter::once(FunctionRef {
-                            ref_path: "#/$functions/notify_error".to_string(),
-                        }))
-                        .collect(),
-                },
-            };
-
-            return Ok(Some(tools));
-        }
-        // Err(InferError::ToolError("No tools provided".to_string()))
-        Ok(None)
-    }
-}
@@ -491,8 +363,10 @@ pub(crate) struct InferResponse {

 #[derive(Debug, Error)]
 pub enum InferError {
+    #[error("Request failed during scheduling: {0}")]
+    Backend(BackendError),
     #[error("Request failed during generation: {0}")]
-    GenerationError(String),
+    GenerationError(BackendError),
     #[error("Model is overloaded")]
     Overloaded(#[from] TryAcquireError),
     #[error("Input validation error: {0}")]
@@ -508,6 +382,7 @@ pub enum InferError {
 impl InferError {
     pub(crate) fn error_type(&self) -> &str {
         match self {
+            InferError::Backend(_) => "backend",
             InferError::GenerationError(_) => "generation",
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
router/src/infer/tool_grammar.rs (new file, 122 lines)
@@ -0,0 +1,122 @@
+use crate::infer::InferError;
+use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
+use serde_json::{json, Map, Value};
+use std::collections::HashMap;
+
+pub(crate) struct ToolGrammar {}
+
+impl ToolGrammar {
+    pub fn apply(
+        tools: Option<Vec<Tool>>,
+        tool_choice: Option<ToolType>,
+    ) -> Result<Option<Tools>, InferError> {
+        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
+            // let tool_prompt = tool_prompt.unwrap_or_default();
+            let tools_to_use = match tool_choice {
+                ToolType::FunctionName(name) => {
+                    vec![req_tools
+                        .iter()
+                        .find(|tool| tool.function.name == *name)
+                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
+                        .clone()]
+                }
+                ToolType::OneOf => req_tools.to_owned(),
+            };
+
+            // adds the error notification function for LLM feedback if required
+            let mut text_response_properties = Map::new();
+            text_response_properties.insert(
+                "error".to_string(),
+                json!({
+                    "type": "string",
+                    "description": "The error or issue to notify"
+                }),
+            );
+            text_response_properties.insert(
+                "_name".to_string(),
+                json!({
+                    "type": "string",
+                    "const": "notify_error"
+                }),
+            );
+
+            let functions: HashMap<String, Value> = tools_to_use
+                .iter()
+                .map(|tool| {
+                    let func = tool.function.clone();
+
+                    // Clone the existing parameters, which are expected to be a JSON object
+                    let mut params = if let Value::Object(params) = &func.arguments {
+                        params.clone()
+                    } else {
+                        Map::new()
+                    };
+
+                    // Insert the function's description at the top level, outside of properties
+                    params.insert(
+                        "description".to_string(),
+                        Value::String(func.description.clone().unwrap_or_default()),
+                    );
+
+                    // Ensure 'properties' exists and is an object
+                    let properties = params
+                        .entry("properties".to_string())
+                        .or_insert_with(|| json!({}))
+                        .as_object_mut()
+                        .unwrap();
+
+                    // Insert the constant for the function name inside 'properties'
+                    properties.insert(
+                        "_name".to_string(),
+                        json!({
+                            "type": "string",
+                            "const": func.name.clone(),
+                            // "description": "The name of the function"
+                        }),
+                    );
+
+                    // Check if 'required' exists, and it is an array. If not, create an empty array.
+                    let required = params
+                        .entry("required".to_string())
+                        .or_insert_with(|| json!([]))
+                        .as_array_mut()
+                        .unwrap();
+
+                    // Add 'name' to the 'required' array if it is not already present
+                    if !required.iter().any(|r| r == "_name") {
+                        required.push(json!("_name"));
+                    }
+
+                    (func.name, Value::Object(params))
+                })
+                .chain([(
+                    "notify_error".to_string(),
+                    json!({
+                        "properties": text_response_properties,
+                        "required": ["error", "_name"],
+                        "type": "object"
+                    }),
+                )])
+                .collect();
+
+            let tools = Tools {
+                functions_map: FunctionsMap { functions },
+                properties: Properties {
+                    function: tools_to_use
+                        .iter()
+                        .map(|tool| FunctionRef {
+                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
+                        })
+                        .chain(std::iter::once(FunctionRef {
+                            ref_path: "#/$functions/notify_error".to_string(),
+                        }))
+                        .collect(),
+                },
+            };
+
+            return Ok(Some(tools));
+        }
+        // Err(InferError::ToolError("No tools provided".to_string()))
+        Ok(None)
+    }
+}
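Note (added for illustration, not part of the commit): ToolGrammar::apply assembles a JSON-schema-like grammar with one entry per selected tool plus an injected notify_error function. The sketch below shows roughly the shape produced for a single hypothetical get_weather tool; the "$functions" and "$ref" key names are inferred from the ref_path strings above and from serde renames elsewhere in the crate, so treat them as an approximation rather than the exact serialized output.

// Approximate output shape of ToolGrammar::apply (illustrative only).
use serde_json::json;

fn main() {
    let grammar = json!({
        "$functions": {
            // hypothetical user-supplied tool
            "get_weather": {
                "description": "Get the current weather",
                "properties": {
                    "location": { "type": "string" },
                    "_name": { "type": "string", "const": "get_weather" }
                },
                "required": ["location", "_name"]
            },
            // error-notification function injected by the code above
            "notify_error": {
                "properties": {
                    "error": { "type": "string", "description": "The error or issue to notify" },
                    "_name": { "type": "string", "const": "notify_error" }
                },
                "required": ["error", "_name"],
                "type": "object"
            }
        },
        "properties": {
            "function": [
                { "$ref": "#/$functions/get_weather" },
                { "$ref": "#/$functions/notify_error" }
            ]
        }
    });
    println!("{}", serde_json::to_string_pretty(&grammar).unwrap());
}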
@@ -121,12 +121,10 @@ responses(
 example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
 )
 )]
-#[instrument(skip(health))]
+#[instrument(skip(infer))]
 /// Health check method
-async fn health(
-    mut health: Extension<HealthCheck>,
-) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    match health.check().await {
+async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match infer.health().await {
         true => Ok(()),
         false => Err((
             StatusCode::SERVICE_UNAVAILABLE,
@@ -437,8 +435,9 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, _input_length, mut response_stream)) => {
+            Ok((_permit, _input_length, response_stream)) => {
                 let mut index = 0;
+                let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     index += 1;
@@ -1960,16 +1959,8 @@ impl From<InferError> for Event {

 #[derive(Debug, Error)]
 pub enum WebServerError {
-    #[error("Unable to connect to the Python model shards: {0}")]
-    Connection(ClientError),
-    #[error("Unable to clear the Python model shards cache: {0}")]
-    Cache(ClientError),
-    #[error("Unable to get the Python model shards info: {0}")]
-    Info(ClientError),
-    #[error("Unable to warmup the Python model shards: {0}")]
-    Warmup(ClientError),
-    #[error("Not enough memory to handle `max_total_tokens={0}`")]
-    NotEnoughMemory(usize),
+    #[error("Backend error: {0}")]
+    Backend(#[from] BackendError),
     #[error("Axum error: {0}")]
     Axum(#[from] axum::BoxError),
 }