feat: avoid unwrap and pre-allocate future vec

drbh 2024-08-30 16:47:24 +00:00
parent b5dd58f73b
commit ed8c7726ba


@@ -1407,120 +1407,127 @@ async fn vertex_compatibility(
     }

     // Prepare futures for all instances
-    let futures: Vec<_> = req
-        .instances
-        .iter()
-        .map(|instance| {
-            let generate_request = match instance {
-                VertexInstance::Generate(instance) => GenerateRequest {
-                    inputs: instance.inputs.clone(),
-                    add_special_tokens: true,
-                    parameters: GenerateParameters {
-                        do_sample: true,
-                        max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
-                        seed: instance.parameters.as_ref().and_then(|p| p.seed),
-                        details: true,
-                        decoder_input_details: true,
-                        ..Default::default()
-                    },
-                },
-                VertexInstance::Chat(instance) => {
-                    let ChatRequest {
-                        model,
-                        max_tokens,
-                        messages,
-                        seed,
-                        stop,
-                        stream,
-                        tools,
-                        tool_choice,
-                        tool_prompt,
-                        temperature,
-                        response_format,
-                        guideline,
-                        presence_penalty,
-                        frequency_penalty,
-                        top_p,
-                        top_logprobs,
-                        ..
-                    } = instance.clone();
-
-                    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-                    let max_new_tokens = max_tokens.or(Some(100));
-                    let tool_prompt = tool_prompt
-                        .filter(|s| !s.is_empty())
-                        .unwrap_or_else(default_tool_prompt);
-                    let stop = stop.unwrap_or_default();
-
-                    // enable greedy only when temperature is 0
-                    let (do_sample, temperature) = match temperature {
-                        Some(temperature) if temperature == 0.0 => (false, None),
-                        other => (true, other),
-                    };
-
-                    let (inputs, grammar, _using_tools) = prepare_chat_input(
-                        &infer,
-                        response_format,
-                        tools,
-                        tool_choice,
-                        &tool_prompt,
-                        guideline,
-                        messages,
-                    )
-                    .unwrap();
-
-                    // build the request passing some parameters
-                    GenerateRequest {
-                        inputs: inputs.to_string(),
-                        add_special_tokens: false,
-                        parameters: GenerateParameters {
-                            best_of: None,
-                            temperature,
-                            repetition_penalty,
-                            frequency_penalty,
-                            top_k: None,
-                            top_p,
-                            typical_p: None,
-                            do_sample,
-                            max_new_tokens,
-                            return_full_text: None,
-                            stop,
-                            truncate: None,
-                            watermark: false,
-                            details: true,
-                            decoder_input_details: !stream,
-                            seed,
-                            top_n_tokens: top_logprobs,
-                            grammar,
-                            adapter_id: model.filter(|m| *m != "tgi").map(String::from),
-                        },
-                    }
-                }
-            };
-
-            let infer_clone = infer.clone();
-            let compute_type_clone = compute_type.clone();
-            let span_clone = span.clone();
-
-            async move {
-                generate_internal(
-                    Extension(infer_clone),
-                    compute_type_clone,
-                    Json(generate_request),
-                    span_clone,
-                )
-                .await
-                .map(|(_, Json(generation))| generation.generated_text)
-                .map_err(|_| {
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".into(),
-                            error_type: "Incomplete generation".into(),
-                        }),
-                    )
-                })
-            }
-        })
-        .collect();
+    let mut futures = Vec::with_capacity(req.instances.len());
+
+    for instance in req.instances.iter() {
+        let generate_request = match instance {
+            VertexInstance::Generate(instance) => GenerateRequest {
+                inputs: instance.inputs.clone(),
+                add_special_tokens: true,
+                parameters: GenerateParameters {
+                    do_sample: true,
+                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                    details: true,
+                    decoder_input_details: true,
+                    ..Default::default()
+                },
+            },
+            VertexInstance::Chat(instance) => {
+                let ChatRequest {
+                    model,
+                    max_tokens,
+                    messages,
+                    seed,
+                    stop,
+                    stream,
+                    tools,
+                    tool_choice,
+                    tool_prompt,
+                    temperature,
+                    response_format,
+                    guideline,
+                    presence_penalty,
+                    frequency_penalty,
+                    top_p,
+                    top_logprobs,
+                    ..
+                } = instance.clone();
+
+                let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                let max_new_tokens = max_tokens.or(Some(100));
+                let tool_prompt = tool_prompt
+                    .filter(|s| !s.is_empty())
+                    .unwrap_or_else(default_tool_prompt);
+                let stop = stop.unwrap_or_default();
+
+                // enable greedy only when temperature is 0
+                let (do_sample, temperature) = match temperature {
+                    Some(temperature) if temperature == 0.0 => (false, None),
+                    other => (true, other),
+                };
+
+                let (inputs, grammar, _using_tools) = match prepare_chat_input(
+                    &infer,
+                    response_format,
+                    tools,
+                    tool_choice,
+                    &tool_prompt,
+                    guideline,
+                    messages,
+                ) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::BAD_REQUEST,
+                            Json(ErrorResponse {
+                                error: format!("Failed to prepare chat input: {}", e),
+                                error_type: "Input preparation error".to_string(),
+                            }),
+                        ));
+                    }
+                };
+
+                GenerateRequest {
+                    inputs: inputs.to_string(),
+                    add_special_tokens: false,
+                    parameters: GenerateParameters {
+                        best_of: None,
+                        temperature,
+                        repetition_penalty,
+                        frequency_penalty,
+                        top_k: None,
+                        top_p,
+                        typical_p: None,
+                        do_sample,
+                        max_new_tokens,
+                        return_full_text: None,
+                        stop,
+                        truncate: None,
+                        watermark: false,
+                        details: true,
+                        decoder_input_details: !stream,
+                        seed,
+                        top_n_tokens: top_logprobs,
+                        grammar,
+                        adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    },
+                }
+            }
+        };
+
+        let infer_clone = infer.clone();
+        let compute_type_clone = compute_type.clone();
+        let span_clone = span.clone();
+
+        futures.push(async move {
+            generate_internal(
+                Extension(infer_clone),
+                compute_type_clone,
+                Json(generate_request),
+                span_clone,
+            )
+            .await
+            .map(|(_, Json(generation))| generation.generated_text)
+            .map_err(|_| {
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Incomplete generation".into(),
+                        error_type: "Incomplete generation".into(),
+                    }),
+                )
+            })
+        });
+    }

     // execute all futures in parallel, collect results, returning early if any error occurs
     let results = futures::future::join_all(futures).await;
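
The diff boils down to two general Rust patterns: when the final length of a collection is known up front, allocate it once with Vec::with_capacity and push into it from a plain for loop; and when a call can fail, match on the Result and return an error to the caller instead of calling .unwrap() and panicking inside the handler. A minimal, self-contained sketch of both patterns follows; the Item type and prepare function are hypothetical stand-ins for illustration and are not part of the TGI codebase.

// Minimal sketch of the same two patterns, outside the TGI router.
// `Item` and `prepare` are hypothetical stand-ins, not part of the actual codebase.
#[derive(Debug)]
struct Item {
    value: i64,
}

// A fallible preparation step, playing the role of `prepare_chat_input`.
fn prepare(raw: &str) -> Result<Item, String> {
    raw.trim()
        .parse::<i64>()
        .map(|value| Item { value })
        .map_err(|e| format!("invalid input {raw:?}: {e}"))
}

// Pre-allocate the output Vec and return an error instead of panicking on bad input.
fn build_items(raw_inputs: &[&str]) -> Result<Vec<Item>, (u16, String)> {
    // The final length is known up front, so reserve it once instead of growing the Vec.
    let mut items = Vec::with_capacity(raw_inputs.len());

    for raw in raw_inputs {
        // Instead of `prepare(raw).unwrap()`, turn the failure into an early return.
        let item = match prepare(raw) {
            Ok(item) => item,
            Err(e) => return Err((400, format!("Failed to prepare input: {e}"))),
        };
        items.push(item);
    }

    Ok(items)
}

fn main() {
    // Happy path: every input parses, and the Vec is allocated exactly once.
    println!("{:?}", build_items(&["1", "2", "3"]));
    // Error path: the bad input becomes an Err value rather than a panic.
    println!("{:?}", build_items(&["1", "oops"]));
}

The for loop (rather than an iterator chain) is what makes the early return possible: a return inside a .map() closure would only exit the closure, whereas here it propagates the BAD_REQUEST-style error straight out of the enclosing function.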