From 897ed20842148222ddedc24007fddda309b0a75d Mon Sep 17 00:00:00 2001 From: Suvro Ghosh Date: Mon, 30 Oct 2023 19:23:19 -0400 Subject: [PATCH] Included model_id in response. The /generate API call will now return the name of the model that generated the response. --- router/src/lib.rs | 3 +++ router/src/server.rs | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 560b8f74..599f5112 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -265,6 +265,9 @@ pub(crate) struct GenerateResponse { pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
, + // Model which generated the text + #[schema(example = "bigscience/blomm-560m")] + pub model_id: String, } #[derive(Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index f254afd8..abb7bd34 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -56,6 +56,7 @@ example = json ! ({"error": "Incomplete generation"})), async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, + info: Extension, Json(mut req): Json, ) -> Result)> { // default return_full_text given the pipeline_tag @@ -69,7 +70,7 @@ async fn compat_generate( .await .into_response()) } else { - let (headers, Json(generation)) = generate(infer, Json(req.into())).await?; + let (headers, Json(generation)) = generate(infer, info, Json(req.into())).await?; // wrap generation inside a Vec to match api-inference Ok((headers, Json(vec![generation])).into_response()) } @@ -144,6 +145,7 @@ seed, )] async fn generate( infer: Extension, + info: Extension, Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); @@ -288,9 +290,13 @@ async fn generate( tracing::debug!("Output: {}", output_text); tracing::info!("Success"); + let model_id: &str = &info.model_id; + let model_id = model_id.to_string(); + let response = GenerateResponse { generated_text: output_text, details, + model_id, }; Ok((headers, Json(response))) }