mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: avoid unwrap and pre allocate future vec
This commit is contained in:
parent
b5dd58f73b
commit
ed8c7726ba
@ -1407,120 +1407,127 @@ async fn vertex_compatibility(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prepare futures for all instances
|
// Prepare futures for all instances
|
||||||
let futures: Vec<_> = req
|
let mut futures = Vec::with_capacity(req.instances.len());
|
||||||
.instances
|
|
||||||
.iter()
|
for instance in req.instances.iter() {
|
||||||
.map(|instance| {
|
let generate_request = match instance {
|
||||||
let generate_request = match instance {
|
VertexInstance::Generate(instance) => GenerateRequest {
|
||||||
VertexInstance::Generate(instance) => GenerateRequest {
|
inputs: instance.inputs.clone(),
|
||||||
inputs: instance.inputs.clone(),
|
add_special_tokens: true,
|
||||||
add_special_tokens: true,
|
parameters: GenerateParameters {
|
||||||
parameters: GenerateParameters {
|
do_sample: true,
|
||||||
do_sample: true,
|
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
details: true,
|
||||||
details: true,
|
decoder_input_details: true,
|
||||||
decoder_input_details: true,
|
..Default::default()
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
VertexInstance::Chat(instance) => {
|
},
|
||||||
let ChatRequest {
|
VertexInstance::Chat(instance) => {
|
||||||
model,
|
let ChatRequest {
|
||||||
max_tokens,
|
model,
|
||||||
messages,
|
max_tokens,
|
||||||
seed,
|
messages,
|
||||||
stop,
|
seed,
|
||||||
stream,
|
stop,
|
||||||
tools,
|
stream,
|
||||||
tool_choice,
|
tools,
|
||||||
tool_prompt,
|
tool_choice,
|
||||||
temperature,
|
tool_prompt,
|
||||||
response_format,
|
temperature,
|
||||||
guideline,
|
response_format,
|
||||||
presence_penalty,
|
guideline,
|
||||||
frequency_penalty,
|
presence_penalty,
|
||||||
top_p,
|
frequency_penalty,
|
||||||
top_logprobs,
|
top_p,
|
||||||
..
|
top_logprobs,
|
||||||
} = instance.clone();
|
..
|
||||||
|
} = instance.clone();
|
||||||
|
|
||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
let tool_prompt = tool_prompt
|
let tool_prompt = tool_prompt
|
||||||
.filter(|s| !s.is_empty())
|
.filter(|s| !s.is_empty())
|
||||||
.unwrap_or_else(default_tool_prompt);
|
.unwrap_or_else(default_tool_prompt);
|
||||||
let stop = stop.unwrap_or_default();
|
let stop = stop.unwrap_or_default();
|
||||||
// enable greedy only when temperature is 0
|
// enable greedy only when temperature is 0
|
||||||
let (do_sample, temperature) = match temperature {
|
let (do_sample, temperature) = match temperature {
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
other => (true, other),
|
other => (true, other),
|
||||||
};
|
};
|
||||||
let (inputs, grammar, _using_tools) = prepare_chat_input(
|
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
tool_choice,
|
tool_choice,
|
||||||
&tool_prompt,
|
&tool_prompt,
|
||||||
guideline,
|
guideline,
|
||||||
messages,
|
messages,
|
||||||
)
|
) {
|
||||||
.unwrap();
|
Ok(result) => result,
|
||||||
|
Err(e) => {
|
||||||
// build the request passing some parameters
|
return Err((
|
||||||
GenerateRequest {
|
StatusCode::BAD_REQUEST,
|
||||||
inputs: inputs.to_string(),
|
Json(ErrorResponse {
|
||||||
add_special_tokens: false,
|
error: format!("Failed to prepare chat input: {}", e),
|
||||||
parameters: GenerateParameters {
|
error_type: "Input preparation error".to_string(),
|
||||||
best_of: None,
|
}),
|
||||||
temperature,
|
));
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
top_k: None,
|
|
||||||
top_p,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample,
|
|
||||||
max_new_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop,
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: top_logprobs,
|
|
||||||
grammar,
|
|
||||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
GenerateRequest {
|
||||||
|
inputs: inputs.to_string(),
|
||||||
|
add_special_tokens: false,
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
best_of: None,
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
top_k: None,
|
||||||
|
top_p,
|
||||||
|
typical_p: None,
|
||||||
|
do_sample,
|
||||||
|
max_new_tokens,
|
||||||
|
return_full_text: None,
|
||||||
|
stop,
|
||||||
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
|
details: true,
|
||||||
|
decoder_input_details: !stream,
|
||||||
|
seed,
|
||||||
|
top_n_tokens: top_logprobs,
|
||||||
|
grammar,
|
||||||
|
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
let infer_clone = infer.clone();
|
|
||||||
let compute_type_clone = compute_type.clone();
|
|
||||||
let span_clone = span.clone();
|
|
||||||
|
|
||||||
async move {
|
|
||||||
generate_internal(
|
|
||||||
Extension(infer_clone),
|
|
||||||
compute_type_clone,
|
|
||||||
Json(generate_request),
|
|
||||||
span_clone,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map(|(_, Json(generation))| generation.generated_text)
|
|
||||||
.map_err(|_| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Incomplete generation".into(),
|
|
||||||
error_type: "Incomplete generation".into(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
})
|
};
|
||||||
.collect();
|
|
||||||
|
let infer_clone = infer.clone();
|
||||||
|
let compute_type_clone = compute_type.clone();
|
||||||
|
let span_clone = span.clone();
|
||||||
|
|
||||||
|
futures.push(async move {
|
||||||
|
generate_internal(
|
||||||
|
Extension(infer_clone),
|
||||||
|
compute_type_clone,
|
||||||
|
Json(generate_request),
|
||||||
|
span_clone,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map(|(_, Json(generation))| generation.generated_text)
|
||||||
|
.map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Incomplete generation".into(),
|
||||||
|
error_type: "Incomplete generation".into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||||
let results = futures::future::join_all(futures).await;
|
let results = futures::future::join_all(futures).await;
|
||||||
|
Loading…
Reference in New Issue
Block a user