diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..723ef36f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.idea
\ No newline at end of file
diff --git a/k6/load_test.js b/k6/load_test.js
new file mode 100644
index 00000000..3f3791af
--- /dev/null
+++ b/k6/load_test.js
@@ -0,0 +1,97 @@
+import http from 'k6/http';
+import {check, sleep} from 'k6';
+
+export const options = {
+    stages: [
+        {duration: '1m', target: 50},
+        {duration: '2m', target: 100},
+        {duration: '1m', target: 0},
+    ],
+    hosts: {
+        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
+    },
+};
+const SLEEP_DURATION = 1;
+
+function greedy_example(inputs, max_new_tokens, name) {
+    let body = JSON.stringify({
+        inputs: inputs,
+        parameters: {
+            max_new_tokens: max_new_tokens,
+            do_sample: false,
+        }
+    });
+    let params = {
+        headers: {
+            'Content-Type': 'application/json',
+        },
+        tags: {
+            name: name
+        }
+    };
+    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+}
+
+function sample_example(inputs, max_new_tokens, name) {
+    let body = JSON.stringify({
+        inputs: inputs,
+        parameters: {
+            max_new_tokens: max_new_tokens,
+            do_sample: true,
+            top_p: 0.9
+        }
+    });
+    let params = {
+        headers: {
+            'Content-Type': 'application/json',
+        },
+        tags: {
+            name: name
+        }
+    };
+    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+}
+
+export default function () {
+    const response_1 = sample_example('A "whatpu" is a small, furry animal native to Tanzania. An example of a sentence that uses the word whatpu is: We were traveling in Africa and we saw these very cute whatpus. To do a "farduddle" means to jump up and down really fast. An example of a sentence that uses the word farduddle is:', 32, 'example-1');
+    check(response_1, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_2 = sample_example("A poem about the beauty of science by Alfred Edgar Brittle\\nTitle: The Magic Craft\\nIn the old times", 50, "example-2");
+    check(response_2, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_3 = greedy_example("استخراج العدد العاملي في لغة بايثون: ", 30, "example-3");
+    check(response_3, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_4 = sample_example("Pour déguster un ortolan, il faut tout d'abord", 32, "example-4");
+    check(response_4, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_5 = sample_example("Traduce español de España a español de Argentina\nEl coche es rojo - el auto es rojo\nEl ordenador es nuevo - la computadora es nueva\nel boligrafo es negro -", 16, "example-5");
+    check(response_5, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_6 = sample_example("Question: If I put cheese into the fridge, will it melt?\nAnswer:", 32, "example-6");
+    check(response_6, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_7 = greedy_example("Question: Where does the Greek Goddess Persephone spend half of the year when she is not with her mother?\nAnswer:", 24, "example-7");
+    check(response_7, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+}
\ No newline at end of file
diff --git a/router/src/infer.rs b/router/src/batcher.rs
similarity index 69%
rename from router/src/infer.rs
rename to router/src/batcher.rs
index 2a7aa0ac..2da47dfc 100644
--- a/router/src/infer.rs
+++ b/router/src/batcher.rs
@@ -1,14 +1,17 @@
-use crate::{Db, GenerateRequest};
-use bloom_inference_client::{Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient};
+use crate::Db;
+use bloom_inference_client::{
+    Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient,
+};
 use std::sync::Arc;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::{Notify, oneshot};
+use crate::server::GenerateRequest;
 
 const MAX_LENGTH: usize = 128;
 
 pub struct InferError {}
 
 #[derive(Clone)]
-pub(crate) struct Infer {
+pub(crate) struct Batcher {
     db: Db,
     shared: Arc<Shared>,
 }
@@ -17,7 +20,7 @@ struct Shared {
     batching_task: Notify,
 }
 
-impl Infer {
+impl Batcher {
     pub(crate) fn new(client: ShardedClient) -> Self {
         let db = Db::new();
         let shared = Arc::new(Shared {
@@ -38,7 +41,7 @@
         self.shared.batching_task.notify_waiters();
         match request_rx.await.unwrap() {
             Ok(output) => Ok(output),
-            Err(_) => Err(InferError {})
+            Err(_) => Err(InferError {}),
         }
     }
 }
@@ -57,19 +60,19 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
                let mut max_sequence_length = entry.sequence_length;
                let mut request_ids = entry.request_ids;
 
-                if total_batch_size <= 16 {
-                    if let Some(batch) = db.next_batch_minimum_size(16, 48) {
-                        let other_cache_entry = infer_batch(batch, &client, &db).await;
-
-                        if let Some(entry) = other_cache_entry {
-                            batch_cached_ids.push(entry.id);
-                            total_batch_size += entry.request_ids.len();
-                            max_sequence_length =
-                                max_sequence_length.max(entry.sequence_length);
-                            request_ids.extend(entry.request_ids.into_iter());
-                        }
-                    }
-                }
+                // if total_batch_size <= 16 {
+                //     if let Some(batch) = db.next_batch_minimum_size(16, 48) {
+                //         let other_cache_entry = infer_batch(batch, &client, &db).await;
+                //
+                //         if let Some(entry) = other_cache_entry {
+                //             batch_cached_ids.push(entry.id);
+                //             total_batch_size += entry.request_ids.len();
+                //             max_sequence_length =
+                //                 max_sequence_length.max(entry.sequence_length);
+                //             request_ids.extend(entry.request_ids.into_iter());
+                //         }
+                //     }
+                // }
 
                let batch_cached = BatchCached {
                    id: entry.id,
@@ -87,7 +90,11 @@
     }
 }
 
-async fn infer_batch_cached(batch: BatchCached, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
+async fn infer_batch_cached(
+    batch: BatchCached,
+    client: &ShardedClient,
+    db: &Db,
+) -> Option<CacheEntry> {
     match client.generate_with_cache(batch.clone()).await {
         Ok((finished, cache_entry)) => {
             send_finished(finished, db);
@@ -109,7 +116,11 @@ async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
         }
         Err(err) => {
             println!("{:?}", err);
-            send_error(err, batch.requests.into_iter().map(|req| req.id).collect(), &db);
+            send_error(
+                err,
+                batch.requests.into_iter().map(|req| req.id).collect(),
+                &db,
+            );
             None
         }
     }
diff --git a/router/src/db.rs b/router/src/db.rs
index 3dd98d94..b6d218e2 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -1,5 +1,5 @@
 /// This code is massively inspired by Tokio mini-redis
-use crate::GenerateRequest;
+use crate::server::GenerateRequest;
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::RwLock;
 use std::collections::BTreeMap;
@@ -44,7 +44,11 @@ impl Db {
         Self { shared }
     }
 
-    pub(crate) fn append(&self, request: GenerateRequest, sender: Sender<Result<String, ClientError>>) {
+    pub(crate) fn append(
+        &self,
+        request: GenerateRequest,
+        sender: Sender<Result<String, ClientError>>,
+    ) {
         let mut state = self.shared.state.write();
 
         let id = state.next_id;
@@ -65,7 +69,10 @@ impl Db {
         state.entries.insert(id, (request, sender));
     }
 
-    pub(crate) fn remove(&self, id: &u64) -> Option<(Request, Sender<Result<String, ClientError>>)> {
+    pub(crate) fn remove(
+        &self,
+        id: &u64,
+    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
         let mut state = self.shared.state.write();
         state.entries.remove(id)
     }
diff --git a/router/src/main.rs b/router/src/main.rs
index 5d87cd46..97ccb571 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -1,105 +1,15 @@
-use tokio::time::Instant;
-
-use poem;
-use poem::middleware::AddData;
-use poem::web::Data;
-use poem::{handler, listener::TcpListener, post, web::Json, EndpointExt, Result, Route, Server};
-
 use bloom_inference_client::ShardedClient;
-use serde::Deserialize;
+use poem;
+use poem::listener::TcpListener;
 use std::time::Duration;
-use poem::http::StatusCode;
-use tracing::instrument;
+
+mod server;
 
 mod db;
-
 use db::Db;
 
-mod infer;
-
-use infer::Infer;
-
-#[derive(Clone, Debug, Deserialize)]
-struct GenerateParameters {
-    #[serde(default = "default_temperature")]
-    temperature: f32,
-    #[serde(default = "default_top_k")]
-    top_k: u32,
-    #[serde(default = "default_top_p")]
-    top_p: f32,
-    #[serde(default = "default_do_sample")]
-    do_sample: bool,
-    #[serde(default = "default_max_new_tokens")]
-    max_new_tokens: u32,
-}
-
-fn default_temperature() -> f32 {
-    1.0
-}
-
-fn default_top_k() -> u32 {
-    0
-}
-
-fn default_top_p() -> f32 {
-    1.0
-}
-
-fn default_do_sample() -> bool {
-    false
-}
-
-fn default_max_new_tokens() -> u32 {
-    20
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct GenerateRequest {
-    inputs: String,
-    #[serde(default = "default_parameters")]
-    parameters: GenerateParameters,
-}
-
-fn default_parameters() -> GenerateParameters {
-    GenerateParameters {
-        temperature: default_temperature(),
-        top_k: default_top_k(),
-        top_p: default_top_p(),
-        do_sample: default_do_sample(),
-        max_new_tokens: default_max_new_tokens(),
-    }
-}
-
-#[handler]
-#[instrument(skip(infer), fields(time, time_per_token))]
-async fn generate(
-    infer: Data<&Infer>,
-    req: Json<GenerateRequest>,
-) -> Result<Json<serde_json::Value>> {
-    let start = Instant::now();
-
-    let output = infer
-        .infer(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
-        .await;
-
-    match output {
-        Ok(generated_text) => {
-            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-            tracing::Span::current().record("time_per_token", format!("{:?}", start.elapsed() / req.parameters.max_new_tokens));
-            tracing::info!("response: {}", generated_text);
-
-            Ok(Json(serde_json::json!({
-                "generated_text": generated_text,
-            })))
-        }
-        Err(_) => {
-            Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
-        }
-    }
-}
+mod batcher;
+use batcher::Batcher;
 
 #[tokio::main]
 async fn main() -> Result<(), std::io::Error> {
@@ -114,12 +24,8 @@
         .expect("Unable to clear cache");
     tracing::info!("Connected");
 
-    let infer = Infer::new(sharded_client);
+    let addr = "127.0.0.1:3000".to_string();
+    let listener = TcpListener::bind(addr);
 
-    let app = Route::new()
-        .at("/generate", post(generate))
-        .with(AddData::new(infer));
-    Server::new(TcpListener::bind("127.0.0.1:3000"))
-        .run(app)
-        .await
+    server::run(sharded_client, listener).await
 }
diff --git a/router/src/server.rs b/router/src/server.rs
new file mode 100644
index 00000000..0daf8df3
--- /dev/null
+++ b/router/src/server.rs
@@ -0,0 +1,111 @@
+use poem::{EndpointExt, handler, post, Route, Server};
+use poem::http::StatusCode;
+use poem::listener::TcpListener;
+use poem::middleware::AddData;
+use poem::web::{Data, Json};
+use tokio::time::Instant;
+use crate::{Batcher, ShardedClient};
+use tracing::instrument;
+use serde::Deserialize;
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateParameters {
+    #[serde(default = "default_temperature")]
+    pub temperature: f32,
+    #[serde(default = "default_top_k")]
+    pub top_k: u32,
+    #[serde(default = "default_top_p")]
+    pub top_p: f32,
+    #[serde(default = "default_do_sample")]
+    pub do_sample: bool,
+    #[serde(default = "default_max_new_tokens")]
+    pub max_new_tokens: u32,
+}
+
+fn default_temperature() -> f32 {
+    1.0
+}
+
+fn default_top_k() -> u32 {
+    0
+}
+
+fn default_top_p() -> f32 {
+    1.0
+}
+
+fn default_do_sample() -> bool {
+    false
+}
+
+fn default_max_new_tokens() -> u32 {
+    20
+}
+
+fn default_parameters() -> GenerateParameters {
+    GenerateParameters {
+        temperature: default_temperature(),
+        top_k: default_top_k(),
+        top_p: default_top_p(),
+        do_sample: default_do_sample(),
+        max_new_tokens: default_max_new_tokens(),
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateRequest {
+    pub inputs: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
+}
+
+
+#[handler]
+#[instrument(skip(infer), fields(time, time_per_token))]
+async fn generate(
+    infer: Data<&Batcher>,
+    req: Json<GenerateRequest>,
+) -> poem::Result<Json<serde_json::Value>> {
+    let start = Instant::now();
+
+    let output = infer
+        .infer(GenerateRequest {
+            inputs: req.inputs.clone(),
+            parameters: req.parameters.clone(),
+        })
+        .await;
+
+    match output {
+        Ok(generated_text) => {
+            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
+            tracing::Span::current().record(
+                "time_per_token",
+                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
+            );
+            tracing::info!("response: {}", generated_text);
+
+            Ok(Json(serde_json::json!({
+                "generated_text": generated_text,
+            })))
+        }
+        Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
+    }
+}
+
+pub async fn run(client: ShardedClient, listener: TcpListener) -> Result<(), std::io::Error> {
+    client
+        .clear_cache()
+        .await
+        .expect("Unable to clear cache");
+    tracing::info!("Connected");
+
+    let infer = Batcher::new(client);
+
+    let app = Route::new()
+        .at("/generate", post(generate))
+        .with(AddData::new(infer));
+
+    Server::new(listener)
+        .run(app)
+        .await
+}
\ No newline at end of file