fix: decrease default batch size, refactor request building, and include index in batch responses

drbh 2024-04-11 20:37:35 +00:00
parent 16be5a14b3
commit 908acc55b8
4 changed files with 143 additions and 132 deletions

File 1 of 4:

@@ -416,7 +416,7 @@ struct Args {
     env: bool,
     /// Control the maximum number of inputs that a client can send in a single request
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }
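Because the attribute keeps `env`, deployments that relied on the old limit of 32 can restore it without rebuilding: clap resolves the CLI flag first, then the environment variable derived from the field name, then the default. A standalone sketch of that precedence (a toy binary, not the project's launcher), assuming clap with the "derive" and "env" features:

```rust
// Sketch only: --max-client-batch-size on the CLI wins, then the derived
// MAX_CLIENT_BATCH_SIZE environment variable, then the new default of 4.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}

fn main() {
    let args = Args::parse();
    // With no flag and no env var set, this prints 4.
    println!("max_client_batch_size = {}", args.max_client_batch_size);
}
```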

File 2 of 4:

@@ -293,6 +293,9 @@ mod prompt_serde {
         let value = Value::deserialize(deserializer)?;
         match value {
             Value::String(s) => Ok(vec![s]),
+            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
+                "Empty array detected. Do not use an empty array for the prompt.",
+            )),
             Value::Array(arr) => arr
                 .iter()
                 .map(|v| match v {
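The new guard turns an empty `prompt` array into an explicit deserialization error instead of a silently empty batch. Below is a self-contained sketch of the same idea; the empty-array arm and its message come from the diff, while the remaining arms are assumptions about the parts of `prompt_serde` this hunk does not show:

```rust
// Requires serde (with "derive") and serde_json.
use serde::de::{Deserializer, Error};
use serde::Deserialize;
use serde_json::Value;

fn deserialize_prompt<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
    let value = Value::deserialize(deserializer)?;
    match value {
        Value::String(s) => Ok(vec![s]),
        Value::Array(arr) if arr.is_empty() => Err(D::Error::custom(
            "Empty array detected. Do not use an empty array for the prompt.",
        )),
        Value::Array(arr) => arr
            .iter()
            .map(|v| match v {
                Value::String(s) => Ok(s.clone()),
                // Assumption: non-string entries are rejected.
                _ => Err(D::Error::custom("prompt entries must be strings")),
            })
            .collect(),
        _ => Err(D::Error::custom("prompt must be a string or an array of strings")),
    }
}

#[derive(Deserialize)]
struct CompletionRequest {
    #[serde(deserialize_with = "deserialize_prompt")]
    prompt: Vec<String>,
}

fn main() {
    // A bare string and a non-empty array both deserialize...
    let one: CompletionRequest = serde_json::from_str(r#"{"prompt": "hello"}"#).unwrap();
    assert_eq!(one.prompt, vec!["hello"]);
    // ...but an empty array is now rejected instead of producing zero requests.
    assert!(serde_json::from_str::<CompletionRequest>(r#"{"prompt": []}"#).is_err());
}
```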

File 3 of 4:

@@ -78,7 +78,7 @@ struct Args {
     messages_api_enabled: bool,
     #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }

File 4 of 4:

@@ -613,10 +613,10 @@ async fn completions(
         ));
     }
-    let mut generate_requests = Vec::new();
-    for prompt in req.prompt.iter() {
-        // build the request passing some parameters
-        let generate_request = GenerateRequest {
+    let generate_requests: Vec<GenerateRequest> = req
+        .prompt
+        .iter()
+        .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
@@ -638,9 +638,8 @@ async fn completions(
                 top_n_tokens: None,
                 grammar: None,
             },
-        };
-        generate_requests.push(generate_request);
-    }
+        })
+        .collect();

     let mut x_compute_type = "unknown".to_string();
     let mut x_compute_characters = 0u32;
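Together with the previous hunk, this is a behavior-preserving refactor: the imperative loop that pushed each `GenerateRequest` into a `Vec` becomes a single `map`/`collect` pipeline. The shape of the change in isolation, with a toy type standing in for `GenerateRequest`:

```rust
// Toy illustration of the refactor pattern only.
struct Request {
    inputs: String,
}

// Before: imperative loop pushing into a mutable Vec.
fn build_requests_loop(prompts: &[String]) -> Vec<Request> {
    let mut requests = Vec::new();
    for prompt in prompts.iter() {
        requests.push(Request {
            inputs: prompt.to_string(),
        });
    }
    requests
}

// After: the same output as a map/collect pipeline, no mutable accumulator.
fn build_requests_map(prompts: &[String]) -> Vec<Request> {
    prompts
        .iter()
        .map(|prompt| Request {
            inputs: prompt.to_string(),
        })
        .collect()
}

fn main() {
    let prompts = vec!["a".to_string(), "b".to_string()];
    assert_eq!(
        build_requests_loop(&prompts).len(),
        build_requests_map(&prompts).len()
    );
}
```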
@@ -767,23 +766,28 @@ async fn completions(
         };

         let sse = Sse::new(stream).keep_alive(KeepAlive::default());
-        return Ok((headers, sse).into_response());
-    }
+        Ok((headers, sse).into_response())
+    } else {
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();

         let responses = FuturesUnordered::new();
-        for generate_request in generate_requests.into_iter() {
-            responses.push(generate(
-                Extension(infer.clone()),
-                Extension(compute_type.clone()),
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let response_future = async move {
+                let result = generate(
+                    Extension(infer_clone),
+                    Extension(compute_type_clone),
                     Json(generate_request),
-            ));
+                )
+                .await;
+                result.map(|(headers, generation)| (index, headers, generation))
+            };
+            responses.push(response_future);
         }
         let generate_responses = responses.try_collect::<Vec<_>>().await?;

         let mut prompt_tokens = 0u32;
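This hunk carries the "include index in batch" part of the commit: `FuturesUnordered` yields results in completion order rather than submission order, so each generate future is wrapped in an `async move` block that returns its prompt index alongside the headers and generation. A minimal runnable sketch of that pattern, with simulated sleeps standing in for the real `generate` call; assumes tokio (features "macros", "rt", "time") and futures as dependencies:

```rust
use futures::stream::{FuturesUnordered, TryStreamExt};
use std::time::Duration;

#[tokio::main]
async fn main() -> Result<(), String> {
    let prompts = vec!["first", "second", "third"];

    let responses = FuturesUnordered::new();
    for (index, prompt) in prompts.into_iter().enumerate() {
        let response_future = async move {
            // Later prompts finish sooner, so completion order is reversed here.
            tokio::time::sleep(Duration::from_millis(30 - 10 * index as u64)).await;
            Ok::<_, String>((index, format!("generated for {prompt}")))
        };
        responses.push(response_future);
    }

    // try_collect gathers results in completion order, but the index survives
    // with each result, so callers can still match responses to prompts.
    let results: Vec<(usize, String)> = responses.try_collect().await?;
    for (index, text) in &results {
        println!("prompt {index}: {text}");
    }
    Ok(())
}
```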
@@ -801,8 +805,7 @@ async fn completions(
         let choices = generate_responses
             .into_iter()
-            .enumerate()
-            .map(|(index, (headers, Json(generation)))| {
+            .map(|(index, headers, Json(generation))| {
                 let details = generation.details.unwrap_or_default();
                 if index == 0 {
                     x_compute_type = headers
@@ -868,7 +871,11 @@ async fn completions(
             object: "text_completion".to_string(),
             created: current_time,
             model: info.model_id.clone(),
-            system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
             choices,
             usage: Usage {
                 prompt_tokens,
@@ -892,6 +899,7 @@ async fn completions(
         Ok((headers, Json(response)).into_response())
     }
+}

 /// Generate tokens
 #[utoipa::path(