diff --git a/router/src/lib.rs b/router/src/lib.rs index 560b8f74..599f5112 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -265,6 +265,9 @@ pub(crate) struct GenerateResponse { pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
, + // Model which generate the text + #[schema(example = "bigscience/blomm-560m")] + pub model_id: String, } #[derive(Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index f254afd8..abb7bd34 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -56,6 +56,7 @@ example = json ! ({"error": "Incomplete generation"})), async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, + info: Extension, Json(mut req): Json, ) -> Result)> { // default return_full_text given the pipeline_tag @@ -69,7 +70,7 @@ async fn compat_generate( .await .into_response()) } else { - let (headers, Json(generation)) = generate(infer, Json(req.into())).await?; + let (headers, Json(generation)) = generate(infer, info, Json(req.into())).await?; // wrap generation inside a Vec to match api-inference Ok((headers, Json(vec![generation])).into_response()) } @@ -144,6 +145,7 @@ seed, )] async fn generate( infer: Extension, + info: Extension, Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); @@ -288,9 +290,13 @@ async fn generate( tracing::debug!("Output: {}", output_text); tracing::info!("Success"); + let model_id: &str = &info.model_id; + let model_id = model_id.to_string(); + let response = GenerateResponse { generated_text: output_text, details, + model_id, }; Ok((headers, Json(response))) }