mirror of https://github.com/huggingface/text-generation-inference.git
fix: decrease default batch, refactors and include index in batch
parent 16be5a14b3
commit 908acc55b8
@@ -416,7 +416,7 @@ struct Args {
     env: bool,

     /// Control the maximum number of inputs that a client can send in a single request
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }
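This hunk (and its twin in the second Args struct below) lowers the default max_client_batch_size from 32 to 4. A minimal sketch of how this clap pattern resolves, assuming clap 4 with the derive and env features enabled; the binary and main function here are illustrative, only the attribute line mirrors the diff:

    // Sketch only: assumes clap 4 with the "derive" and "env" features.
    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Control the maximum number of inputs that a client can send in a single request
        #[clap(default_value = "4", long, env)]
        max_client_batch_size: usize,
    }

    fn main() {
        // Resolution order: --max-client-batch-size flag, then the
        // MAX_CLIENT_BATCH_SIZE environment variable (clap derives the name
        // from the field), then the default of 4.
        let args = Args::parse();
        println!("max_client_batch_size = {}", args.max_client_batch_size);
    }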
@@ -293,6 +293,9 @@ mod prompt_serde {
         let value = Value::deserialize(deserializer)?;
         match value {
             Value::String(s) => Ok(vec![s]),
+            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
+                "Empty array detected. Do not use an empty array for the prompt.",
+            )),
             Value::Array(arr) => arr
                 .iter()
                 .map(|v| match v {
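This hunk adds an explicit rejection of an empty prompt array, so `"prompt": []` fails at deserialization instead of silently producing zero generate requests. Below is a sketch of the whole deserializer the hunk sits in, reconstructed from the visible lines; everything after the `.map(|v| ...)` line is an assumption about how the module completes, not code from the commit:

    // Reconstruction sketch: assumes serde and serde_json, and a prompt field
    // deserialized with #[serde(with = "prompt_serde")] into Vec<String>.
    use serde::{Deserialize, Deserializer};
    use serde_json::Value;

    fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = Value::deserialize(deserializer)?;
        match value {
            // A bare string becomes a one-element batch.
            Value::String(s) => Ok(vec![s]),
            // New guard from this commit: reject [] up front.
            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
                "Empty array detected. Do not use an empty array for the prompt.",
            )),
            // An array must contain only strings (the arms below are assumed).
            Value::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    Value::String(s) => Ok(s.to_owned()),
                    _ => Err(serde::de::Error::custom("expected a string")),
                })
                .collect(),
            _ => Err(serde::de::Error::custom(
                "expected a string or an array of strings",
            )),
        }
    }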
@@ -78,7 +78,7 @@ struct Args {
     messages_api_enabled: bool,
     #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }
@@ -613,10 +613,10 @@ async fn completions(
         ));
     }

-    let mut generate_requests = Vec::new();
-    for prompt in req.prompt.iter() {
-        // build the request passing some parameters
-        let generate_request = GenerateRequest {
+    let generate_requests: Vec<GenerateRequest> = req
+        .prompt
+        .iter()
+        .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
@@ -638,9 +638,8 @@ async fn completions(
                 top_n_tokens: None,
                 grammar: None,
             },
-        };
-        generate_requests.push(generate_request);
-    }
+        })
+        .collect();

     let mut x_compute_type = "unknown".to_string();
     let mut x_compute_characters = 0u32;
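These two hunks are one refactor: the mutable-Vec-plus-for-loop accumulation becomes a single iterator chain ending in .collect(). A self-contained illustration with a stand-in GenerateRequest that keeps only the `inputs` field shown in the diff:

    // Stand-in type: only the field visible in the diff.
    struct GenerateRequest {
        inputs: String,
    }

    fn main() {
        let prompts = vec!["hello".to_string(), "world".to_string()];

        // Before: imperative accumulation into a mutable Vec.
        let mut requests = Vec::new();
        for prompt in prompts.iter() {
            requests.push(GenerateRequest { inputs: prompt.to_string() });
        }

        // After: the same Vec built declaratively, as in the new code.
        let requests2: Vec<GenerateRequest> = prompts
            .iter()
            .map(|prompt| GenerateRequest { inputs: prompt.to_string() })
            .collect();

        assert_eq!(requests.len(), requests2.len());
    }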
@@ -767,23 +766,28 @@ async fn completions(
         };

         let sse = Sse::new(stream).keep_alive(KeepAlive::default());
-        return Ok((headers, sse).into_response());
-    }
-
-    let current_time = std::time::SystemTime::now()
-        .duration_since(std::time::UNIX_EPOCH)
-        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-        .as_secs();
-
-    let responses = FuturesUnordered::new();
-    for generate_request in generate_requests.into_iter() {
-        responses.push(generate(
-            Extension(infer.clone()),
-            Extension(compute_type.clone()),
-            Json(generate_request),
-        ));
-    }
-
-    let generate_responses = responses.try_collect::<Vec<_>>().await?;
-
-    let mut prompt_tokens = 0u32;
+        Ok((headers, sse).into_response())
+    } else {
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let responses = FuturesUnordered::new();
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let response_future = async move {
+                let result = generate(
+                    Extension(infer_clone),
+                    Extension(compute_type_clone),
+                    Json(generate_request),
+                )
+                .await;
+                result.map(|(headers, generation)| (index, headers, generation))
+            };
+            responses.push(response_future);
+        }
+
+        let generate_responses = responses.try_collect::<Vec<_>>().await?;
+
+        let mut prompt_tokens = 0u32;
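This is the "include index in batch" part of the commit: FuturesUnordered yields results in completion order, not submission order, so each generate call is wrapped in an async block that tags its result with the prompt's index, and the next hunk destructures that tuple instead of calling .enumerate() on the collected Vec. A runnable sketch of the pattern, assuming the futures and tokio crates, with a trivial stand-in for the real generate call:

    // Sketch only: the async body stands in for `generate(...)`.
    use futures::stream::{FuturesUnordered, TryStreamExt};

    #[tokio::main]
    async fn main() -> Result<(), std::io::Error> {
        let inputs = vec!["a", "b", "c"];

        // FuturesUnordered::push takes &self, so no `mut` is needed.
        let responses = FuturesUnordered::new();
        for (index, input) in inputs.into_iter().enumerate() {
            let response_future = async move {
                // Tag the result with its submission index, as the diff does.
                Ok::<_, std::io::Error>((index, format!("output for {input}")))
            };
            responses.push(response_future);
        }

        // try_collect gathers results as they complete; the carried index,
        // not the position in this Vec, says which prompt each output
        // belongs to.
        let results: Vec<(usize, String)> = responses.try_collect().await?;
        for (index, output) in results {
            println!("choice {index}: {output}");
        }
        Ok(())
    }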
@@ -801,8 +805,7 @@ async fn completions(

         let choices = generate_responses
             .into_iter()
-            .enumerate()
-            .map(|(index, (headers, Json(generation)))| {
+            .map(|(index, headers, Json(generation))| {
                 let details = generation.details.unwrap_or_default();
                 if index == 0 {
                     x_compute_type = headers
@@ -868,7 +871,11 @@ async fn completions(
             object: "text_completion".to_string(),
             created: current_time,
             model: info.model_id.clone(),
-            system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
             choices,
             usage: Usage {
                 prompt_tokens,
@@ -892,6 +899,7 @@ async fn completions(

         Ok((headers, Json(response)).into_response())
     }
+}

 /// Generate tokens
 #[utoipa::path(
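The extra closing brace here pairs with the `} else {` introduced earlier: the early `return` in the SSE branch becomes the tail expression of an if/else, so the streaming and non-streaming paths flow through a single exit. A toy illustration of the shape, not the real handler:

    // Both branches are the value of the if/else expression; no explicit
    // `return` is needed.
    fn respond(stream: bool) -> Result<String, String> {
        if stream {
            Ok("event-stream response".to_string())
        } else {
            Ok("json response".to_string())
        }
    }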