Sending compute type from the environment instead of hardcoded string

Using env is slow, therefore getting it from global state instead.
This commit is contained in:
Nicolas Patry 2024-01-29 11:15:58 +01:00
parent 4c7315dde5
commit f91fbe9d26

View File

@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag
@ -66,11 +67,11 @@ async fn compat_generate(
// switch on stream
if req.stream {
Ok(generate_stream(infer, Json(req.into()))
Ok(generate_stream(infer,compute_type, Json(req.into()))
.await
.into_response())
} else {
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response())
}
@ -145,6 +146,7 @@ seed,
)]
async fn generate(
infer: Extension<Infer>,
Extension(ComputeType(compute_type)): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
@ -230,7 +232,7 @@ async fn generate(
// Headers
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert(
"x-compute-time",
total_time.as_millis().to_string().parse().unwrap(),
@ -339,6 +341,7 @@ seed,
)]
async fn generate_stream(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> (
HeaderMap,
@ -349,13 +352,14 @@ async fn generate_stream(
event.json_data(stream_token).unwrap()
};
let (headers, response_stream) =
generate_stream_internal(infer, Json(req), on_message_callback).await;
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse)
}
async fn generate_stream_internal(
infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
@ -368,7 +372,7 @@ async fn generate_stream_internal(
let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert("x-compute-type",compute_type.parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
@ -557,6 +561,7 @@ async fn generate_stream_internal(
)]
async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@ -645,12 +650,12 @@ async fn chat_completions(
};
let (headers, response_stream) =
generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, Json(generation)) =
generate(Extension(infer), Json(generate_request)).await?;
generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -729,6 +734,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
}
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
@ -927,6 +935,8 @@ pub async fn run(
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
// Combine routes and layers
let app = Router::new()
.merge(swagger_ui)
@ -936,6 +946,7 @@ pub async fn run(
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(cors_layer);