From b5dd58f73b60005fa7203acbae124426d76b6149 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 30 Aug 2024 16:19:30 +0000
Subject: [PATCH] fix: enable chat requests in vertex endpoint

---
 router/src/lib.rs    |   9 +++-
 router/src/server.rs | 119 +++++++++++++++++++++++++++++++++++--------
 2 files changed, 107 insertions(+), 21 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index a1e1dadf..d8029c72 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
 }
 
 #[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct VertexInstance {
+pub(crate) struct GenerateVertexInstance {
     #[schema(example = "What is Deep Learning?")]
     pub inputs: String,
     #[schema(nullable = true, default = "null", example = "null")]
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema)]
+#[serde(untagged)]
+enum VertexInstance {
+    Generate(GenerateVertexInstance),
+    Chat(ChatRequest),
+}
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
     #[serde(rename = "instances")]
diff --git a/router/src/server.rs b/router/src/server.rs
index d3d34215..10f67b9a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse};
+use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1406,30 +1406,106 @@ async fn vertex_compatibility(
         ));
     }
 
-    // Process all instances
-    let predictions = req
+    // Prepare futures for all instances
+    let futures: Vec<_> = req
         .instances
         .iter()
         .map(|instance| {
-            let generate_request = GenerateRequest {
-                inputs: instance.inputs.clone(),
-                add_special_tokens: true,
-                parameters: GenerateParameters {
-                    do_sample: true,
-                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
-                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
-                    details: true,
-                    decoder_input_details: true,
-                    ..Default::default()
+            let generate_request = match instance {
+                VertexInstance::Generate(instance) => GenerateRequest {
+                    inputs: instance.inputs.clone(),
+                    add_special_tokens: true,
+                    parameters: GenerateParameters {
+                        do_sample: true,
+                        max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                        seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                        details: true,
+                        decoder_input_details: true,
+                        ..Default::default()
+                    },
                 },
+                VertexInstance::Chat(instance) => {
+                    let ChatRequest {
+                        model,
+                        max_tokens,
+                        messages,
+                        seed,
+                        stop,
+                        stream,
+                        tools,
+                        tool_choice,
+                        tool_prompt,
+                        temperature,
+                        response_format,
+                        guideline,
+                        presence_penalty,
+                        frequency_penalty,
+                        top_p,
+                        top_logprobs,
+                        ..
+                    } = instance.clone();
+
+                    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                    let max_new_tokens = max_tokens.or(Some(100));
+                    let tool_prompt = tool_prompt
+                        .filter(|s| !s.is_empty())
+                        .unwrap_or_else(default_tool_prompt);
+                    let stop = stop.unwrap_or_default();
+                    // enable greedy only when temperature is 0
+                    let (do_sample, temperature) = match temperature {
+                        Some(temperature) if temperature == 0.0 => (false, None),
+                        other => (true, other),
+                    };
+                    let (inputs, grammar, _using_tools) = prepare_chat_input(
+                        &infer,
+                        response_format,
+                        tools,
+                        tool_choice,
+                        &tool_prompt,
+                        guideline,
+                        messages,
+                    )
+                    .unwrap();
+
+                    // build the request passing some parameters
+                    GenerateRequest {
+                        inputs: inputs.to_string(),
+                        add_special_tokens: false,
+                        parameters: GenerateParameters {
+                            best_of: None,
+                            temperature,
+                            repetition_penalty,
+                            frequency_penalty,
+                            top_k: None,
+                            top_p,
+                            typical_p: None,
+                            do_sample,
+                            max_new_tokens,
+                            return_full_text: None,
+                            stop,
+                            truncate: None,
+                            watermark: false,
+                            details: true,
+                            decoder_input_details: !stream,
+                            seed,
+                            top_n_tokens: top_logprobs,
+                            grammar,
+                            adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                        },
+                    }
+                }
             };
 
-            async {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let span_clone = span.clone();
+
+            async move {
                 generate_internal(
-                    Extension(infer.clone()),
-                    compute_type.clone(),
+                    Extension(infer_clone),
+                    compute_type_clone,
                     Json(generate_request),
-                    span.clone(),
+                    span_clone,
                 )
                 .await
                 .map(|(_, Json(generation))| generation.generated_text)
@@ -1444,9 +1520,12 @@ async fn vertex_compatibility(
                 })
             }
         })
-        .collect::<FuturesUnordered<_>>()
-        .try_collect::<Vec<_>>()
-        .await?;
+        .collect();
+
+    // execute all futures in parallel, collect results, returning early if any error occurs
+    let results = futures::future::join_all(futures).await;
+    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
+    let predictions = predictions?;
 
     let response = VertexResponse { predictions };
     Ok((HeaderMap::new(), Json(response)).into_response())
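
A few notes on the patch, each with a small standalone Rust sketch.

The routing between the two request shapes comes entirely from #[serde(untagged)]: serde tries each variant in declaration order and keeps the first one that deserializes cleanly, so a payload carrying "inputs" becomes Generate and one carrying "messages" becomes Chat. A minimal sketch of that behavior (assumes the serde and serde_json crates; the structs are simplified stand-ins, not TGI's real GenerateVertexInstance and ChatRequest):

// Sketch only: simplified stand-ins for the real TGI types.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GenerateInstance {
    inputs: String,
}

#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct ChatInstance {
    messages: Vec<Message>,
}

// Untagged: serde tries Generate first, then Chat, and keeps the first
// variant whose shape matches the payload.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Instance {
    Generate(GenerateInstance),
    Chat(ChatInstance),
}

fn main() {
    let g: Instance =
        serde_json::from_str(r#"{"inputs": "What is Deep Learning?"}"#).unwrap();
    let c: Instance = serde_json::from_str(
        r#"{"messages": [{"role": "user", "content": "What is Deep Learning?"}]}"#,
    )
    .unwrap();
    println!("{g:?}"); // Generate(GenerateInstance { .. })
    println!("{c:?}"); // Chat(ChatInstance { .. })
}

Variant order matters with untagged enums: if a payload ever satisfied both shapes, the first variant would win silently.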
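
The chat arm folds OpenAI-style knobs into TGI's GenerateParameters. Two mappings are easy to misread: a temperature of exactly 0.0 is translated into greedy decoding (do_sample = false, temperature dropped), and presence_penalty is shifted by +2.0 before being used as repetition_penalty (the former is additive around 0, the latter a multiplier around 1). The same expressions as the patch, lifted into two hypothetically named pure functions so the edge cases are visible:

// Same expressions as the patch; the function names are illustrative only.
fn sampling(temperature: Option<f32>) -> (bool, Option<f32>) {
    // enable greedy only when temperature is 0
    match temperature {
        Some(t) if t == 0.0 => (false, None),
        other => (true, other),
    }
}

fn repetition_penalty(presence_penalty: Option<f32>) -> Option<f32> {
    // presence_penalty is additive around 0; repetition_penalty is a
    // multiplier around 1, hence the +2.0 shift used by the patch.
    presence_penalty.map(|x| x + 2.0)
}

fn main() {
    assert_eq!(sampling(Some(0.0)), (false, None)); // greedy
    assert_eq!(sampling(Some(0.7)), (true, Some(0.7))); // sampling kept
    assert_eq!(sampling(None), (true, None)); // no temperature: still sample
    assert_eq!(repetition_penalty(Some(-2.0)), Some(0.0));
    assert_eq!(repetition_penalty(None), None);
}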
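
The collection strategy changes as well: FuturesUnordered with try_collect yields results in completion order, while futures::future::join_all preserves input order, so after this patch predictions[i] is guaranteed to line up with instances[i]. One caveat the in-code comment glosses over: join_all drives every future to completion, so an error only surfaces afterwards, when the Vec of Results is collapsed into a single Result. A runnable sketch of the pattern, with a hypothetical predict standing in for generate_internal (assumes the tokio and futures crates):

use std::time::Duration;

async fn predict(i: usize) -> Result<String, String> {
    // Later instances finish first, to show that input order is preserved.
    tokio::time::sleep(Duration::from_millis(100 - 20 * i as u64)).await;
    Ok(format!("prediction {i}"))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let futures: Vec<_> = (0..4).map(predict).collect();

    // join_all polls everything concurrently but yields results in input
    // order; collecting into Result<Vec<_>, _> stops at the first Err.
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
    let predictions = predictions?;

    assert_eq!(
        predictions,
        ["prediction 0", "prediction 1", "prediction 2", "prediction 3"]
    );
    Ok(())
}

If fail-fast behavior mattered here, futures::future::try_join_all would abandon the remaining work at the first error while still preserving order.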