diff --git a/router/src/server.rs b/router/src/server.rs
index 10f67b9a..fac56a77 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1407,120 +1407,127 @@ async fn vertex_compatibility(
     }

     // Prepare futures for all instances
-    let futures: Vec<_> = req
-        .instances
-        .iter()
-        .map(|instance| {
-            let generate_request = match instance {
-                VertexInstance::Generate(instance) => GenerateRequest {
-                    inputs: instance.inputs.clone(),
-                    add_special_tokens: true,
-                    parameters: GenerateParameters {
-                        do_sample: true,
-                        max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
-                        seed: instance.parameters.as_ref().and_then(|p| p.seed),
-                        details: true,
-                        decoder_input_details: true,
-                        ..Default::default()
-                    },
+    let mut futures = Vec::with_capacity(req.instances.len());
+
+    for instance in req.instances.iter() {
+        let generate_request = match instance {
+            VertexInstance::Generate(instance) => GenerateRequest {
+                inputs: instance.inputs.clone(),
+                add_special_tokens: true,
+                parameters: GenerateParameters {
+                    do_sample: true,
+                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                    details: true,
+                    decoder_input_details: true,
+                    ..Default::default()
                 },
-                VertexInstance::Chat(instance) => {
-                    let ChatRequest {
-                        model,
-                        max_tokens,
-                        messages,
-                        seed,
-                        stop,
-                        stream,
-                        tools,
-                        tool_choice,
-                        tool_prompt,
-                        temperature,
-                        response_format,
-                        guideline,
-                        presence_penalty,
-                        frequency_penalty,
-                        top_p,
-                        top_logprobs,
-                        ..
-                    } = instance.clone();
+            },
+            VertexInstance::Chat(instance) => {
+                let ChatRequest {
+                    model,
+                    max_tokens,
+                    messages,
+                    seed,
+                    stop,
+                    stream,
+                    tools,
+                    tool_choice,
+                    tool_prompt,
+                    temperature,
+                    response_format,
+                    guideline,
+                    presence_penalty,
+                    frequency_penalty,
+                    top_p,
+                    top_logprobs,
+                    ..
+                } = instance.clone();

-                    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-                    let max_new_tokens = max_tokens.or(Some(100));
-                    let tool_prompt = tool_prompt
-                        .filter(|s| !s.is_empty())
-                        .unwrap_or_else(default_tool_prompt);
-                    let stop = stop.unwrap_or_default();
-                    // enable greedy only when temperature is 0
-                    let (do_sample, temperature) = match temperature {
-                        Some(temperature) if temperature == 0.0 => (false, None),
-                        other => (true, other),
-                    };
-                    let (inputs, grammar, _using_tools) = prepare_chat_input(
-                        &infer,
-                        response_format,
-                        tools,
-                        tool_choice,
-                        &tool_prompt,
-                        guideline,
-                        messages,
-                    )
-                    .unwrap();
-
-                    // build the request passing some parameters
-                    GenerateRequest {
-                        inputs: inputs.to_string(),
-                        add_special_tokens: false,
-                        parameters: GenerateParameters {
-                            best_of: None,
-                            temperature,
-                            repetition_penalty,
-                            frequency_penalty,
-                            top_k: None,
-                            top_p,
-                            typical_p: None,
-                            do_sample,
-                            max_new_tokens,
-                            return_full_text: None,
-                            stop,
-                            truncate: None,
-                            watermark: false,
-                            details: true,
-                            decoder_input_details: !stream,
-                            seed,
-                            top_n_tokens: top_logprobs,
-                            grammar,
-                            adapter_id: model.filter(|m| *m != "tgi").map(String::from),
-                        },
+                let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                let max_new_tokens = max_tokens.or(Some(100));
+                let tool_prompt = tool_prompt
+                    .filter(|s| !s.is_empty())
+                    .unwrap_or_else(default_tool_prompt);
+                let stop = stop.unwrap_or_default();
+                // enable greedy only when temperature is 0
+                let (do_sample, temperature) = match temperature {
+                    Some(temperature) if temperature == 0.0 => (false, None),
+                    other => (true, other),
+                };
+                let (inputs, grammar, _using_tools) = match prepare_chat_input(
+                    &infer,
+                    response_format,
+                    tools,
+                    tool_choice,
+                    &tool_prompt,
+                    guideline,
+                    messages,
+                ) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::BAD_REQUEST,
+                            Json(ErrorResponse {
+                                error: format!("Failed to prepare chat input: {}", e),
+                                error_type: "Input preparation error".to_string(),
+                            }),
+                        ));
                     }
+                };
+
+                GenerateRequest {
+                    inputs: inputs.to_string(),
+                    add_special_tokens: false,
+                    parameters: GenerateParameters {
+                        best_of: None,
+                        temperature,
+                        repetition_penalty,
+                        frequency_penalty,
+                        top_k: None,
+                        top_p,
+                        typical_p: None,
+                        do_sample,
+                        max_new_tokens,
+                        return_full_text: None,
+                        stop,
+                        truncate: None,
+                        watermark: false,
+                        details: true,
+                        decoder_input_details: !stream,
+                        seed,
+                        top_n_tokens: top_logprobs,
+                        grammar,
+                        adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    },
                 }
-            };
-
-            let infer_clone = infer.clone();
-            let compute_type_clone = compute_type.clone();
-            let span_clone = span.clone();
-
-            async move {
-                generate_internal(
-                    Extension(infer_clone),
-                    compute_type_clone,
-                    Json(generate_request),
-                    span_clone,
-                )
-                .await
-                .map(|(_, Json(generation))| generation.generated_text)
-                .map_err(|_| {
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".into(),
-                            error_type: "Incomplete generation".into(),
-                        }),
-                    )
-                })
             }
-        })
-        .collect();
+        };
+
+        let infer_clone = infer.clone();
+        let compute_type_clone = compute_type.clone();
+        let span_clone = span.clone();
+
+        futures.push(async move {
+            generate_internal(
+                Extension(infer_clone),
+                compute_type_clone,
+                Json(generate_request),
+                span_clone,
+            )
+            .await
+            .map(|(_, Json(generation))| generation.generated_text)
+            .map_err(|_| {
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Incomplete generation".into(),
+                        error_type: "Incomplete generation".into(),
+                    }),
+                )
+            })
+        });
+    }

     // execute all futures in parallel, collect results, returning early if any error occurs
     let results = futures::future::join_all(futures).await;
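Note on the change: the old code built the futures with .map(...).collect(), and because a closure's return cannot exit the surrounding handler, the fallible prepare_chat_input call had to be .unwrap()ed. The rewrite uses a for loop plus futures.push(...) so the handler can return Err((StatusCode::BAD_REQUEST, ...)) before any generation work is queued. Below is a minimal standalone sketch of that control-flow difference; it uses hypothetical types and a synchronous Vec rather than TGI's actual request types and futures.

// Minimal standalone sketch (hypothetical types, not the TGI API) of the
// control-flow difference above: inside a .map(...) closure, `return` only
// exits the closure, so a fallible step had to be unwrap()ed; in a for loop
// the function can return an error before any work is queued.

#[derive(Debug)]
struct PreparedRequest(String);

#[derive(Debug)]
struct BadRequest(String);

// Stand-in for a fallible preparation step such as prepare_chat_input.
fn prepare(input: &str) -> Result<PreparedRequest, BadRequest> {
    if input.is_empty() {
        Err(BadRequest("empty input".to_string()))
    } else {
        Ok(PreparedRequest(input.to_uppercase()))
    }
}

// Loop-based construction: the first preparation error aborts the whole batch.
fn build_batch(instances: &[&str]) -> Result<Vec<PreparedRequest>, BadRequest> {
    let mut batch = Vec::with_capacity(instances.len());
    for instance in instances {
        // `?` propagates the error out of build_batch, which is not possible
        // from inside a .map(...) closure without collecting into a Result.
        batch.push(prepare(instance)?);
    }
    Ok(batch)
}

fn main() {
    assert!(build_batch(&["hello", "world"]).is_ok());
    assert!(build_batch(&["hello", ""]).is_err());
}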