fix: decrease default batch, refactors and include index in batch

drbh 2024-04-11 20:37:35 +00:00
parent 16be5a14b3
commit 908acc55b8
4 changed files with 143 additions and 132 deletions

View File

@@ -416,7 +416,7 @@ struct Args {
env: bool,
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "32", long, env)]
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}

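A note on the hunk above: with clap's derive API, default_value only applies when neither the command-line flag nor the environment variable is given, so lowering the default from 32 to 4 changes behaviour only for deployments that never set --max-client-batch-size (or the MAX_CLIENT_BATCH_SIZE variable that a bare env attribute derives from the field name). A minimal, self-contained sketch of the same pattern, assuming the clap crate with the derive and env features enabled; this is an illustration, not the launcher's actual Args struct:

use clap::Parser;

/// Illustrative stand-in for the launcher argument touched above.
#[derive(Parser, Debug)]
struct Args {
    /// Maximum number of inputs a client can send in a single request.
    /// Resolution order: CLI flag, then the derived env var, then the default of 4.
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}

fn main() {
    let args = Args::parse();
    println!("max_client_batch_size = {}", args.max_client_batch_size);
}

Running the sketch with --max-client-batch-size 16 (or MAX_CLIENT_BATCH_SIZE=16) prints 16; with neither set it prints the new default of 4.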
View File

@@ -293,6 +293,9 @@ mod prompt_serde {
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(vec![s]),
+Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
+"Empty array detected. Do not use an empty array for the prompt.",
+)),
Value::Array(arr) => arr
.iter()
.map(|v| match v {

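For context, the added arm above sits in a custom serde deserializer that lets the prompt field be either a single string or an array of strings; rejecting the empty array keeps a request with an empty prompt list from producing zero generate requests further down. A rough, self-contained sketch of that deserializer shape, with simplified names, assuming the serde (with derive) and serde_json crates; the real router code may differ in details:

use serde::{Deserialize, Deserializer};
use serde_json::Value;

// Accept a prompt given as one string or as a non-empty array of strings.
fn deserialize_prompt<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    let value = Value::deserialize(deserializer)?;
    match value {
        Value::String(s) => Ok(vec![s]),
        // The arm added in this commit: fail fast on an empty array.
        Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
            "Empty array detected. Do not use an empty array for the prompt.",
        )),
        Value::Array(arr) => arr
            .iter()
            .map(|v| match v {
                Value::String(s) => Ok(s.clone()),
                _ => Err(serde::de::Error::custom("prompt array must contain strings")),
            })
            .collect(),
        _ => Err(serde::de::Error::custom(
            "prompt must be a string or an array of strings",
        )),
    }
}

#[derive(Deserialize)]
struct CompletionRequest {
    #[serde(deserialize_with = "deserialize_prompt")]
    prompt: Vec<String>,
}

With this shape, a body whose prompt is "hi" deserializes to one entry, an array of two strings to two entries, and an empty array fails with the "Empty array detected" message instead of silently yielding an empty batch.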
View File

@@ -78,7 +78,7 @@ struct Args {
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "32", long, env)]
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}

View File

@@ -613,10 +613,10 @@ async fn completions(
));
}
-let mut generate_requests = Vec::new();
-for prompt in req.prompt.iter() {
-// build the request passing some parameters
-let generate_request = GenerateRequest {
+let generate_requests: Vec<GenerateRequest> = req
+.prompt
+.iter()
+.map(|prompt| GenerateRequest {
inputs: prompt.to_string(),
parameters: GenerateParameters {
best_of: None,
@@ -638,9 +638,8 @@ async fn completions(
top_n_tokens: None,
grammar: None,
},
-};
-generate_requests.push(generate_request);
-}
+})
+.collect();
let mut x_compute_type = "unknown".to_string();
let mut x_compute_characters = 0u32;
@@ -767,23 +766,28 @@ async fn completions(
};
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
-return Ok((headers, sse).into_response());
-}
+Ok((headers, sse).into_response())
+} else {
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let responses = FuturesUnordered::new();
-for generate_request in generate_requests.into_iter() {
-responses.push(generate(
-Extension(infer.clone()),
-Extension(compute_type.clone()),
+for (index, generate_request) in generate_requests.into_iter().enumerate() {
+let infer_clone = infer.clone();
+let compute_type_clone = compute_type.clone();
+let response_future = async move {
+let result = generate(
+Extension(infer_clone),
+Extension(compute_type_clone),
Json(generate_request),
-));
+)
+.await;
+result.map(|(headers, generation)| (index, headers, generation))
+};
+responses.push(response_future);
}
let generate_responses = responses.try_collect::<Vec<_>>().await?;
let mut prompt_tokens = 0u32;
@@ -801,8 +805,7 @@ async fn completions(
let choices = generate_responses
.into_iter()
-.enumerate()
-.map(|(index, (headers, Json(generation)))| {
+.map(|(index, headers, Json(generation))| {
let details = generation.details.unwrap_or_default();
if index == 0 {
x_compute_type = headers
@@ -868,7 +871,11 @@ async fn completions(
object: "text_completion".to_string(),
created: current_time,
model: info.model_id.clone(),
system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
system_fingerprint: format!(
"{}-{}",
info.version,
info.docker_label.unwrap_or("native")
),
choices,
usage: Usage {
prompt_tokens,
@@ -892,6 +899,7 @@ async fn completions(
Ok((headers, Json(response)).into_response())
+}
}
/// Generate tokens
#[utoipa::path(