Send the compute type from the environment instead of a hardcoded string

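Reading the environment variable on every request is slow, so the value is read once at startup and shared with the handlers through global state (an `Extension`) instead.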
This commit is contained in:
Nicolas Patry 2024-01-29 11:15:58 +01:00
parent 4c7315dde5
commit f91fbe9d26

View File

@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
async fn compat_generate( async fn compat_generate(
Extension(default_return_full_text): Extension<bool>, Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
Json(mut req): Json<CompatGenerateRequest>, Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag // default return_full_text given the pipeline_tag
@ -66,11 +67,11 @@ async fn compat_generate(
// switch on stream // switch on stream
if req.stream { if req.stream {
Ok(generate_stream(infer, Json(req.into())) Ok(generate_stream(infer,compute_type, Json(req.into()))
.await .await
.into_response()) .into_response())
} else { } else {
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?; let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response()) Ok((headers, Json(vec![generation])).into_response())
} }
@ -145,6 +146,7 @@ seed,
)] )]
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
Extension(ComputeType(compute_type)): Extension<ComputeType>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
@ -230,7 +232,7 @@ async fn generate(
// Headers // Headers
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert( headers.insert(
"x-compute-time", "x-compute-time",
total_time.as_millis().to_string().parse().unwrap(), total_time.as_millis().to_string().parse().unwrap(),
@ -339,6 +341,7 @@ seed,
)] )]
async fn generate_stream( async fn generate_stream(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> ( ) -> (
HeaderMap, HeaderMap,
@ -349,13 +352,14 @@ async fn generate_stream(
event.json_data(stream_token).unwrap() event.json_data(stream_token).unwrap()
}; };
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, Json(req), on_message_callback).await; generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse) (headers, sse)
} }
async fn generate_stream_internal( async fn generate_stream_internal(
infer: Infer, infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event, on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) { ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
@ -368,7 +372,7 @@ async fn generate_stream_internal(
let compute_characters = req.inputs.chars().count(); let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); headers.insert("x-compute-type",compute_type.parse().unwrap());
headers.insert( headers.insert(
"x-compute-characters", "x-compute-characters",
compute_characters.to_string().parse().unwrap(), compute_characters.to_string().parse().unwrap(),
@ -557,6 +561,7 @@ async fn generate_stream_internal(
)] )]
async fn chat_completions( async fn chat_completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>, Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@ -645,12 +650,12 @@ async fn chat_completions(
}; };
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, Json(generate_request), on_message_callback).await; generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, Json(generation)) = let (headers, Json(generation)) =
generate(Extension(infer), Json(generate_request)).await?; generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -729,6 +734,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render() prom_handle.render()
} }
/// Newtype around the compute-type label that handlers send back in the
/// `x-compute-type` response header.
///
/// The value is built once in `run` from the `COMPUTE_TYPE` environment
/// variable (falling back to `"gpu+optimized"` when it is unset) and is made
/// available to the request handlers through an `Extension` layer.
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
@ -927,6 +935,8 @@ pub async fn run(
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
}; };
let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
// Combine routes and layers // Combine routes and layers
let app = Router::new() let app = Router::new()
.merge(swagger_ui) .merge(swagger_ui)
@ -936,6 +946,7 @@ pub async fn run(
.layer(Extension(health_ext.clone())) .layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone())) .layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default()) .layer(OtelAxumLayer::default())
.layer(cors_layer); .layer(cors_layer);