Included model_id in response.

The /generate api call will now return the model name which generated the response.
This commit is contained in:
Suvro Ghosh 2023-10-30 19:23:19 -04:00
parent 96a982ad8f
commit 897ed20842
2 changed files with 10 additions and 1 deletions

View File

@ -265,6 +265,9 @@ pub(crate) struct GenerateResponse {
pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>,
// Model which generate the text
#[schema(example = "bigscience/blomm-560m")]
pub model_id: String,
}
#[derive(Serialize, ToSchema)]

View File

@ -56,6 +56,7 @@ example = json ! ({"error": "Incomplete generation"})),
async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag
@ -69,7 +70,7 @@ async fn compat_generate(
.await
.into_response())
} else {
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
let (headers, Json(generation)) = generate(infer, info, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response())
}
@ -144,6 +145,7 @@ seed,
)]
async fn generate(
infer: Extension<Infer>,
info: Extension<Info>,
Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
@ -288,9 +290,13 @@ async fn generate(
tracing::debug!("Output: {}", output_text);
tracing::info!("Success");
let model_id: &str = &info.model_id;
let model_id = model_id.to_string();
let response = GenerateResponse {
generated_text: output_text,
details,
model_id,
};
Ok((headers, Json(response)))
}