Sending compute type from the environment instead of hardcoded string

Reading the environment on every request is slow, so the value is read once at startup and kept in shared state instead.
parent 4c7315dde5
commit f91fbe9d26
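The pattern here: resolve the compute type once at startup, wrap it in a newtype, and attach it to the router as an axum Extension so every handler receives it as shared state, instead of calling std::env::var per request. A minimal, self-contained sketch of that pattern, assuming axum 0.6 and tokio; the route and bind address are illustrative, not from the codebase:

// Sketch of the commit's pattern, assuming axum 0.6 and tokio.
use axum::{routing::get, Extension, Router};

// Newtype so the value gets its own type-keyed slot in axum's extension
// map (a bare Extension<String> could collide with other string state).
#[derive(Clone, Debug)]
struct ComputeType(String);

// Handlers receive the already-resolved value; no env lookup per request.
async fn handler(Extension(ComputeType(compute_type)): Extension<ComputeType>) -> String {
    format!("x-compute-type: {compute_type}")
}

#[tokio::main]
async fn main() {
    // Read the environment exactly once, at startup, with the previous
    // hardcoded value as the fallback.
    let compute_type =
        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

    let app = Router::new()
        .route("/", get(handler))
        .layer(Extension(compute_type));

    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}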
@@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
+    compute_type: Extension<ComputeType>,
     Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // default return_full_text given the pipeline_tag
@@ -66,11 +67,11 @@ async fn compat_generate(

     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, Json(req.into()))
+        Ok(generate_stream(infer, compute_type, Json(req.into()))
             .await
             .into_response())
     } else {
-        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation])).into_response())
     }
@@ -145,6 +146,7 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
+    Extension(ComputeType(compute_type)): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
@@ -230,7 +232,7 @@ async fn generate(

     // Headers
     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-time",
         total_time.as_millis().to_string().parse().unwrap(),
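Note that the x-compute-type header value is no longer a literal, so the .parse().unwrap() now runs on user-supplied configuration. A small sketch of what that conversion does, assuming the http crate's HeaderValue as re-exported by axum:

use axum::http::HeaderValue;

// `str::parse::<HeaderValue>()` goes through `HeaderValue::from_str`,
// which rejects control characters such as newlines; the `unwrap` in the
// diff therefore panics at request time if COMPUTE_TYPE was set to a
// string that is not a valid header value.
fn to_header_value(compute_type: &str) -> HeaderValue {
    compute_type.parse::<HeaderValue>().unwrap()
}

fn main() {
    assert_eq!(to_header_value("gpu+optimized"), "gpu+optimized");
}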
@@ -339,6 +341,7 @@ seed,
 )]
 async fn generate_stream(
     Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
@@ -349,13 +352,14 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req), on_message_callback).await;
+        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }

 async fn generate_stream_internal(
     infer: Infer,
+    ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     on_message_callback: impl Fn(StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
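Two destructuring forms appear above: generate_stream keeps the newtype intact so it can forward it, while generate_stream_internal (and generate, earlier) pattern-match all the way down to the inner String. A short sketch of the two forms side by side, with illustrative function names:

use axum::Extension;

#[derive(Clone, Debug)]
struct ComputeType(String);

// Keep the newtype intact so it can be passed along, as generate_stream does:
async fn outer(Extension(compute_type): Extension<ComputeType>) {
    inner(compute_type).await;
}

// Unwrap both the Extension and the newtype where the String is used,
// as generate and generate_stream_internal do:
async fn inner(ComputeType(compute_type): ComputeType) {
    let _value: String = compute_type;
}

#[tokio::main]
async fn main() {
    outer(Extension(ComputeType("gpu+optimized".to_string()))).await;
}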
@@ -368,7 +372,7 @@ async fn generate_stream_internal(
     let compute_characters = req.inputs.chars().count();

     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-characters",
         compute_characters.to_string().parse().unwrap(),
@@ -557,6 +561,7 @@ async fn generate_stream_internal(
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -645,12 +650,12 @@ async fn chat_completions(
         };

         let (headers, response_stream) =
-            generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
+            generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
         let (headers, Json(generation)) =
-            generate(Extension(infer), Json(generate_request)).await?;
+            generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;

         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -729,6 +734,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }

+#[derive(Clone, Debug)]
+pub(crate) struct ComputeType(String);
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
@@ -927,6 +935,8 @@ pub async fn run(
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
     };

+    let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
+
     // Combine routes and layers
     let app = Router::new()
         .merge(swagger_ui)
@@ -936,6 +946,7 @@ pub async fn run(
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
+        .layer(Extension(compute_type))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
         .layer(cors_layer);