Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit b96fe73beb: use IntMap
Parent: 67ee1907fc

Cargo.lock (generated, +7)

@@ -1087,6 +1087,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "nohash-hasher"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
+
 [[package]]
 name = "nom"
 version = "7.1.1"

@@ -1826,6 +1832,7 @@ dependencies = [
  "axum",
  "clap 4.0.22",
  "futures",
+ "nohash-hasher",
  "parking_lot",
  "serde",
  "serde_json",

router/Cargo.toml

@@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
+nohash-hasher = "0.2.0"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
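
The only new dependency is nohash-hasher, which provides IntMap: a std HashMap whose hasher passes integer keys through untouched instead of running them through SipHash. A minimal standalone sketch of the type (illustrative, not code from this commit):

```rust
use nohash_hasher::IntMap;

fn main() {
    // IntMap<u64, V> = HashMap<u64, V, BuildNoHashHasher<u64>>: the "hash"
    // of a key is the key itself, so lookups skip the default SipHash work.
    let mut entries: IntMap<u64, &str> = IntMap::default();
    entries.insert(42, "queued request");
    assert_eq!(entries.get(&42), Some(&"queued request"));
}
```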

router/src/batcher.rs

@@ -1,15 +1,14 @@
-use std::collections::HashMap;
 /// Batching and inference logic
 use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
+use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
-use tokio::sync::{Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -26,8 +25,6 @@ pub struct Batcher {
 struct Shared {
     /// Batching background Tokio task notifier
     batching_task: Notify,
-    /// Inference request limit
-    limit_concurrent_requests: Semaphore,
 }
 
 impl Batcher {

@@ -35,13 +32,11 @@ impl Batcher {
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
-        max_concurrent_requests: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
-            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
         });
 
         // Spawn batching background task that contains all the inference logic

@@ -62,9 +57,6 @@ impl Batcher {
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
-        // Limit concurrent requests by acquiring a permit from the semaphore
-        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
-
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
@@ -151,8 +143,7 @@ async fn batching_task(
                 }
             }
 
-            cached_batch =
-                wrap_future(client.generate_with_cache(batches), &mut entries).await;
+            cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
             waiting_tokens += 1;
         }
     }

@@ -162,7 +153,7 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    entries: &mut HashMap<u64, Entry>,
+    entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {

@@ -178,7 +169,7 @@ async fn wrap_future(
 }
 
 /// Send errors to the Batcher for all `entries`
-fn send_error(error: ClientError, entries: &mut HashMap<u64, Entry>) {
+fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());

@@ -186,7 +177,7 @@ fn send_error(error: ClientError, entries: &mut HashMap<u64, Entry>) {
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, entries: &mut HashMap<u64, Entry>) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
         // We can `expect` here as the request id should always be in the entries
         let entry = entries

@@ -226,30 +217,18 @@ pub(crate) struct InferResponse {
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
     GenerationError(String),
-    #[error("Model is overloaded")]
-    Overloaded,
-}
-
-/// Convert semaphore error
-impl From<TryAcquireError> for InferError {
-    fn from(_: TryAcquireError) -> Self {
-        InferError::Overloaded
-    }
 }
 
 /// Convert to Axum supported format
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
-        let status_code = match err {
-            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
-            InferError::Overloaded => StatusCode::TOO_MANY_REQUESTS,
-        };
-
-        (
-            status_code,
+        match err {
+            InferError::GenerationError(_) => (
+                StatusCode::FAILED_DEPENDENCY,
             Json(ErrorResponse {
                 error: err.to_string(),
             }),
-        )
+            ),
+        }
     }
 }
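
With the signatures above switched to IntMap<u64, Entry>, each request-id lookup hashes to the id itself. A small standalone demonstration of that property (not part of this commit; nohash-hasher only accepts primitive integer keys, gated by its IsEnabled trait):

```rust
use nohash_hasher::BuildNoHashHasher;
use std::hash::{BuildHasher, Hash, Hasher};

fn main() {
    // Hashing a u64 with NoHashHasher returns the value unchanged: the
    // hasher records write_u64(n) and finish() hands back n.
    let build = BuildNoHashHasher::<u64>::default();
    let mut hasher = build.build_hasher();
    42u64.hash(&mut hasher);
    assert_eq!(hasher.finish(), 42);
}
```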

router/src/db.rs

@@ -1,8 +1,9 @@
-use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
+use crate::InferResponse;
 use crate::{GenerateParameters, GenerateRequest};
+use nohash_hasher::{BuildNoHashHasher, IntMap};
 use parking_lot::Mutex;
-use std::collections::{BTreeMap, HashMap};
+use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,

@@ -117,7 +118,7 @@ impl Db {
         &self,
         min_size: Option<usize>,
         max_size: usize,
-    ) -> Option<(HashMap<u64, Entry>, Batch)> {
+    ) -> Option<(IntMap<u64, Entry>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
 
@@ -132,13 +133,13 @@ impl Db {
         // Batch size
         let size = requests.len();
 
-        let mut entries = HashMap::with_capacity(size);
+        let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
         ids.iter().for_each(|id| {
             // Remove entry from db
             let mut entry = state.entries.remove(id).unwrap();
             // Set batch_time
             entry.batch_time = Some(Instant::now());
-            // Insert in entries hashmap
+            // Insert in entries IntMap
             entries.insert(*id, entry);
         });
 
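
Note the constructor change in the hunk above: HashMap::with_capacity is only defined for the default RandomState hasher, so pre-sizing a map with a custom hasher goes through with_capacity_and_hasher. A minimal sketch:

```rust
use nohash_hasher::{BuildNoHashHasher, IntMap};

fn main() {
    // Pre-size the map for a known batch size; the hasher must be passed
    // explicitly because with_capacity assumes the default RandomState.
    let size = 32;
    let mut entries: IntMap<u64, &str> =
        IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
    entries.insert(7, "batched entry");
    assert!(entries.capacity() >= size);
}
```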

router/src/server.rs

@@ -7,9 +7,11 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
+use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
+use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -18,6 +20,7 @@ use tracing::instrument;
 struct ServerState {
     validation: Validation,
     batcher: Batcher,
+    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Health check method
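
The semaphore now lives in ServerState rather than inside the Batcher. Since Axum hands each handler a clone of the shared state, the limit has to sit behind an Arc so every clone draws permits from the same pool. A reduced sketch (ServerState is trimmed here to the new field only; the real struct also carries validation and batcher):

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

// Trimmed stand-in for the router's ServerState.
#[derive(Clone)]
struct ServerState {
    limit_concurrent_requests: Arc<Semaphore>,
}

fn main() {
    let state = ServerState {
        limit_concurrent_requests: Arc::new(Semaphore::new(8)),
    };
    // Cloning the state clones the Arc, not the semaphore: both handles
    // count against the same pool of 8 permits.
    let handler_copy = state.clone();
    assert_eq!(Arc::strong_count(&handler_copy.limit_concurrent_requests), 2);
}
```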

@@ -27,6 +30,16 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     // be a bit too slow for a health check.
     // What we should do instead if check if the gRPC channels are still healthy.
 
+    // Limit concurrent requests by acquiring a permit from the semaphore
+    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
+        (
+            StatusCode::TOO_MANY_REQUESTS,
+            Json(ErrorResponse {
+                error: "Model is overloaded".to_string(),
+            }),
+        )
+    })?;
+
     // Send a small inference request
     state
         .batcher

@@ -65,6 +78,16 @@ async fn generate(
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
+    // Limit concurrent requests by acquiring a permit from the semaphore
+    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
+        tracing::error!("Model is overloaded");
+        (
+            StatusCode::TOO_MANY_REQUESTS,
+            Json(ErrorResponse {
+                error: "Model is overloaded".to_string(),
+            }),
+        )
+    })?;
 
     // Validate request
     let details = req.0.parameters.details;
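
Both handlers use the same admission pattern: try_acquire either returns a permit that is held in _permit until the handler returns, or fails immediately, which map_err turns into a 429 response instead of queueing the caller. The RAII behavior in isolation (illustrative names, not router code):

```rust
use tokio::sync::Semaphore;

fn main() {
    // try_acquire never awaits, so this runs without a Tokio runtime.
    let limit = Semaphore::new(2);

    // Two in-flight "requests" take both permits without waiting.
    let first = limit.try_acquire().expect("first slot");
    let _second = limit.try_acquire().expect("second slot");

    // A third caller is rejected immediately; the server maps this to 429.
    assert!(limit.try_acquire().is_err());

    // Dropping a permit (the handler returning) frees the slot again.
    drop(first);
    assert!(limit.try_acquire().is_ok());
}
```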

@@ -162,16 +185,12 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(
-        client,
-        max_batch_size,
-        max_waiting_tokens,
-        max_concurrent_requests,
-    );
+    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
         batcher,
+        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
     };
 
     // Create router