fix: enable chat requests in vertex endpoint

drbh 2024-08-30 16:19:30 +00:00
parent d9fbbaafb0
commit b5dd58f73b
2 changed files with 107 additions and 21 deletions

View File

@@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
 }
 
 #[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct VertexInstance {
+pub(crate) struct GenerateVertexInstance {
     #[schema(example = "What is Deep Learning?")]
     pub inputs: String,
     #[schema(nullable = true, default = "null", example = "null")]
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema)]
+#[serde(untagged)]
+enum VertexInstance {
+    Generate(GenerateVertexInstance),
+    Chat(ChatRequest),
+}
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
     #[serde(rename = "instances")]
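With #[serde(untagged)], serde tries the variants in declaration order, so an instance carrying an "inputs" string deserializes as Generate, while one carrying "messages" falls through to Chat. Below is a minimal, self-contained sketch of that dispatch using simplified stand-in types (the real GenerateVertexInstance and ChatRequest carry many more fields):

use serde::Deserialize;

// Simplified stand-ins for illustration only; not the TGI definitions.
#[derive(Debug, Deserialize)]
struct GenerateVertexInstance {
    inputs: String,
}

#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct ChatRequest {
    messages: Vec<Message>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(ChatRequest),
}

fn main() -> Result<(), serde_json::Error> {
    // Has "inputs", so the first (Generate) variant matches.
    let generate: VertexInstance =
        serde_json::from_str(r#"{"inputs": "What is Deep Learning?"}"#)?;
    // No "inputs" but has "messages", so deserialization falls through to Chat.
    let chat: VertexInstance = serde_json::from_str(
        r#"{"messages": [{"role": "user", "content": "What is Deep Learning?"}]}"#,
    )?;
    println!("{generate:?}");
    println!("{chat:?}");
    Ok(())
}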

View File

@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse};
+use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1406,12 +1406,13 @@ async fn vertex_compatibility(
         ));
     }
 
-    // Process all instances
-    let predictions = req
+    // Prepare futures for all instances
+    let futures: Vec<_> = req
         .instances
         .iter()
        .map(|instance| {
-            let generate_request = GenerateRequest {
-                inputs: instance.inputs.clone(),
-                add_special_tokens: true,
-                parameters: GenerateParameters {
+            let generate_request = match instance {
+                VertexInstance::Generate(instance) => GenerateRequest {
+                    inputs: instance.inputs.clone(),
+                    add_special_tokens: true,
+                    parameters: GenerateParameters {
@@ -1422,14 +1423,89 @@ async fn vertex_compatibility(
-                    decoder_input_details: true,
-                    ..Default::default()
-                },
+                        decoder_input_details: true,
+                        ..Default::default()
+                    },
+                },
+                VertexInstance::Chat(instance) => {
+                    let ChatRequest {
+                        model,
+                        max_tokens,
+                        messages,
+                        seed,
+                        stop,
+                        stream,
+                        tools,
+                        tool_choice,
+                        tool_prompt,
+                        temperature,
+                        response_format,
+                        guideline,
+                        presence_penalty,
+                        frequency_penalty,
+                        top_p,
+                        top_logprobs,
+                        ..
+                    } = instance.clone();
+
+                    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                    let max_new_tokens = max_tokens.or(Some(100));
+                    let tool_prompt = tool_prompt
+                        .filter(|s| !s.is_empty())
+                        .unwrap_or_else(default_tool_prompt);
+                    let stop = stop.unwrap_or_default();
+
+                    // enable greedy only when temperature is 0
+                    let (do_sample, temperature) = match temperature {
+                        Some(temperature) if temperature == 0.0 => (false, None),
+                        other => (true, other),
+                    };
+
+                    let (inputs, grammar, _using_tools) = prepare_chat_input(
+                        &infer,
+                        response_format,
+                        tools,
+                        tool_choice,
+                        &tool_prompt,
+                        guideline,
+                        messages,
+                    )
+                    .unwrap();
+
+                    // build the request passing some parameters
+                    GenerateRequest {
+                        inputs: inputs.to_string(),
+                        add_special_tokens: false,
+                        parameters: GenerateParameters {
+                            best_of: None,
+                            temperature,
+                            repetition_penalty,
+                            frequency_penalty,
+                            top_k: None,
+                            top_p,
+                            typical_p: None,
+                            do_sample,
+                            max_new_tokens,
+                            return_full_text: None,
+                            stop,
+                            truncate: None,
+                            watermark: false,
+                            details: true,
+                            decoder_input_details: !stream,
+                            seed,
+                            top_n_tokens: top_logprobs,
+                            grammar,
+                            adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                        },
+                    }
+                }
             };
 
-            async {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let span_clone = span.clone();
+            async move {
                 generate_internal(
-                    Extension(infer.clone()),
-                    compute_type.clone(),
+                    Extension(infer_clone),
+                    compute_type_clone,
                     Json(generate_request),
-                    span.clone(),
+                    span_clone,
                 )
                 .await
                 .map(|(_, Json(generation))| generation.generated_text)
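Two details of the chat-to-generate mapping above are easy to miss: the OpenAI-style presence_penalty is shifted by +2.0 before being used as TGI's repetition_penalty, and sampling is disabled only when the client sends a temperature of exactly 0. A standalone sketch of that second rule (an assumed simplification, not an actual TGI helper):

// Sketch only: mirrors the `match temperature` in the chat branch above.
// A temperature of exactly 0.0 means greedy decoding (no sampling) and the
// temperature is dropped; any other value, including None, keeps sampling on.
fn sampling_params(temperature: Option<f32>) -> (bool, Option<f32>) {
    match temperature {
        Some(t) if t == 0.0 => (false, None),
        other => (true, other),
    }
}

fn main() {
    assert_eq!(sampling_params(Some(0.0)), (false, None)); // greedy
    assert_eq!(sampling_params(Some(0.7)), (true, Some(0.7))); // sampling
    assert_eq!(sampling_params(None), (true, None)); // sampling, no explicit temperature
}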
@@ -1444,9 +1520,12 @@ async fn vertex_compatibility(
                 })
             }
         })
-        .collect::<FuturesUnordered<_>>()
-        .try_collect::<Vec<_>>()
-        .await?;
+        .collect();
+
+    // execute all futures in parallel, collect results, returning early if any error occurs
+    let results = futures::future::join_all(futures).await;
+    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
+    let predictions = predictions?;
 
     let response = VertexResponse { predictions };
     Ok((HeaderMap::new(), Json(response)).into_response())
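The old FuturesUnordered / try_collect pipeline is replaced by collecting plain futures into a Vec and awaiting them with futures::future::join_all. Note that join_all drives every future to completion; the `?` only takes effect afterwards, when the collected Results are folded into one. A minimal, self-contained sketch of that pattern (hypothetical predict function standing in for the generate_internal call, using only the futures crate):

use futures::executor::block_on;
use futures::future::join_all;

// Hypothetical stand-in for one generate_internal call.
async fn predict(input: &str) -> Result<String, String> {
    if input.is_empty() {
        Err("empty instance".to_string())
    } else {
        Ok(format!("prediction for '{input}'"))
    }
}

fn main() {
    let instances = vec!["What is Deep Learning?", "Explain LoRA in one line"];

    // Prepare futures for all instances (nothing runs yet).
    let futures: Vec<_> = instances.into_iter().map(|i| predict(i)).collect();

    // join_all waits for every future; collect() then surfaces the first Err, if any.
    let results = block_on(join_all(futures));
    let predictions: Result<Vec<String>, String> = results.into_iter().collect();
    println!("{predictions:?}");
}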