mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: avoid unwrap and pre allocate future vec
This commit is contained in:
parent
b5dd58f73b
commit
ed8c7726ba
@ -1407,120 +1407,127 @@ async fn vertex_compatibility(
|
||||
}
|
||||
|
||||
// Prepare futures for all instances
|
||||
let futures: Vec<_> = req
|
||||
.instances
|
||||
.iter()
|
||||
.map(|instance| {
|
||||
let generate_request = match instance {
|
||||
VertexInstance::Generate(instance) => GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
do_sample: true,
|
||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||
details: true,
|
||||
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
let mut futures = Vec::with_capacity(req.instances.len());
|
||||
|
||||
for instance in req.instances.iter() {
|
||||
let generate_request = match instance {
|
||||
VertexInstance::Generate(instance) => GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
do_sample: true,
|
||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||
details: true,
|
||||
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
VertexInstance::Chat(instance) => {
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
..
|
||||
} = instance.clone();
|
||||
},
|
||||
VertexInstance::Chat(instance) => {
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
..
|
||||
} = instance.clone();
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, _using_tools) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// build the request passing some parameters
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k: None,
|
||||
top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to prepare chat input: {}", e),
|
||||
error_type: "Input preparation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k: None,
|
||||
top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
async move {
|
||||
generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
};
|
||||
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
futures.push(async move {
|
||||
generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||
let results = futures::future::join_all(futures).await;
|
||||
|
Loading…
Reference in New Issue
Block a user