mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Send the compute type from the environment instead of a hardcoded string
Reading the environment on every request is slow, so the value is read once and then served from global state instead.
This commit is contained in:
parent
4c7315dde5
commit
f91fbe9d26
@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
|
|||||||
async fn compat_generate(
|
async fn compat_generate(
|
||||||
Extension(default_return_full_text): Extension<bool>,
|
Extension(default_return_full_text): Extension<bool>,
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
|
compute_type: Extension<ComputeType>,
|
||||||
Json(mut req): Json<CompatGenerateRequest>,
|
Json(mut req): Json<CompatGenerateRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
// default return_full_text given the pipeline_tag
|
// default return_full_text given the pipeline_tag
|
||||||
@ -66,11 +67,11 @@ async fn compat_generate(
|
|||||||
|
|
||||||
// switch on stream
|
// switch on stream
|
||||||
if req.stream {
|
if req.stream {
|
||||||
Ok(generate_stream(infer, Json(req.into()))
|
Ok(generate_stream(infer,compute_type, Json(req.into()))
|
||||||
.await
|
.await
|
||||||
.into_response())
|
.into_response())
|
||||||
} else {
|
} else {
|
||||||
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
|
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
|
||||||
// wrap generation inside a Vec to match api-inference
|
// wrap generation inside a Vec to match api-inference
|
||||||
Ok((headers, Json(vec![generation])).into_response())
|
Ok((headers, Json(vec![generation])).into_response())
|
||||||
}
|
}
|
||||||
@ -145,6 +146,7 @@ seed,
|
|||||||
)]
|
)]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
|
Extension(ComputeType(compute_type)): Extension<ComputeType>,
|
||||||
Json(req): Json<GenerateRequest>,
|
Json(req): Json<GenerateRequest>,
|
||||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
@ -230,7 +232,7 @@ async fn generate(
|
|||||||
|
|
||||||
// Headers
|
// Headers
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
headers.insert("x-compute-type", compute_type.parse().unwrap());
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"x-compute-time",
|
"x-compute-time",
|
||||||
total_time.as_millis().to_string().parse().unwrap(),
|
total_time.as_millis().to_string().parse().unwrap(),
|
||||||
@ -339,6 +341,7 @@ seed,
|
|||||||
)]
|
)]
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Json(req): Json<GenerateRequest>,
|
Json(req): Json<GenerateRequest>,
|
||||||
) -> (
|
) -> (
|
||||||
HeaderMap,
|
HeaderMap,
|
||||||
@ -349,13 +352,14 @@ async fn generate_stream(
|
|||||||
event.json_data(stream_token).unwrap()
|
event.json_data(stream_token).unwrap()
|
||||||
};
|
};
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) =
|
||||||
generate_stream_internal(infer, Json(req), on_message_callback).await;
|
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
|
||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
(headers, sse)
|
(headers, sse)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_stream_internal(
|
async fn generate_stream_internal(
|
||||||
infer: Infer,
|
infer: Infer,
|
||||||
|
ComputeType(compute_type): ComputeType,
|
||||||
Json(req): Json<GenerateRequest>,
|
Json(req): Json<GenerateRequest>,
|
||||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||||
@ -368,7 +372,7 @@ async fn generate_stream_internal(
|
|||||||
let compute_characters = req.inputs.chars().count();
|
let compute_characters = req.inputs.chars().count();
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
headers.insert("x-compute-type",compute_type.parse().unwrap());
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"x-compute-characters",
|
"x-compute-characters",
|
||||||
compute_characters.to_string().parse().unwrap(),
|
compute_characters.to_string().parse().unwrap(),
|
||||||
@ -557,6 +561,7 @@ async fn generate_stream_internal(
|
|||||||
)]
|
)]
|
||||||
async fn chat_completions(
|
async fn chat_completions(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
Json(req): Json<ChatRequest>,
|
Json(req): Json<ChatRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
@ -645,12 +650,12 @@ async fn chat_completions(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) =
|
||||||
generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
|
generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
|
||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
Ok((headers, sse).into_response())
|
Ok((headers, sse).into_response())
|
||||||
} else {
|
} else {
|
||||||
let (headers, Json(generation)) =
|
let (headers, Json(generation)) =
|
||||||
generate(Extension(infer), Json(generate_request)).await?;
|
generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
|
||||||
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
@ -729,6 +734,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
|||||||
prom_handle.render()
|
prom_handle.render()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Newtype around the string sent in the `x-compute-type` response header.
///
/// `run` builds it from the `COMPUTE_TYPE` environment variable, falling back
/// to `"gpu+optimized"` when the variable cannot be read, and registers it as
/// an `Extension` layer; the handlers take it as an `Extension<ComputeType>`
/// argument.
#[derive(Clone, Debug)]
|
||||||
|
pub(crate) struct ComputeType(String);
|
||||||
|
|
||||||
/// Serving method
|
/// Serving method
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
@ -927,6 +935,8 @@ pub async fn run(
|
|||||||
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
|
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
|
||||||
|
|
||||||
// Combine routes and layers
|
// Combine routes and layers
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(swagger_ui)
|
.merge(swagger_ui)
|
||||||
@ -936,6 +946,7 @@ pub async fn run(
|
|||||||
.layer(Extension(health_ext.clone()))
|
.layer(Extension(health_ext.clone()))
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
|
.layer(Extension(compute_type))
|
||||||
.layer(Extension(prom_handle.clone()))
|
.layer(Extension(prom_handle.clone()))
|
||||||
.layer(OtelAxumLayer::default())
|
.layer(OtelAxumLayer::default())
|
||||||
.layer(cors_layer);
|
.layer(cors_layer);
|
||||||
|
Loading…
Reference in New Issue
Block a user