mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Included model_id in response.
The /generate api call will now return the model name which generated the response.
This commit is contained in:
parent
96a982ad8f
commit
897ed20842
@ -265,6 +265,9 @@ pub(crate) struct GenerateResponse {
|
|||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
|
// Model which generate the text
|
||||||
|
#[schema(example = "bigscience/blomm-560m")]
|
||||||
|
pub model_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
|
@ -56,6 +56,7 @@ example = json ! ({"error": "Incomplete generation"})),
|
|||||||
async fn compat_generate(
|
async fn compat_generate(
|
||||||
Extension(default_return_full_text): Extension<bool>,
|
Extension(default_return_full_text): Extension<bool>,
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
|
info: Extension<Info>,
|
||||||
Json(mut req): Json<CompatGenerateRequest>,
|
Json(mut req): Json<CompatGenerateRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
// default return_full_text given the pipeline_tag
|
// default return_full_text given the pipeline_tag
|
||||||
@ -69,7 +70,7 @@ async fn compat_generate(
|
|||||||
.await
|
.await
|
||||||
.into_response())
|
.into_response())
|
||||||
} else {
|
} else {
|
||||||
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
|
let (headers, Json(generation)) = generate(infer, info, Json(req.into())).await?;
|
||||||
// wrap generation inside a Vec to match api-inference
|
// wrap generation inside a Vec to match api-inference
|
||||||
Ok((headers, Json(vec![generation])).into_response())
|
Ok((headers, Json(vec![generation])).into_response())
|
||||||
}
|
}
|
||||||
@ -144,6 +145,7 @@ seed,
|
|||||||
)]
|
)]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
|
info: Extension<Info>,
|
||||||
Json(req): Json<GenerateRequest>,
|
Json(req): Json<GenerateRequest>,
|
||||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
@ -288,9 +290,13 @@ async fn generate(
|
|||||||
tracing::debug!("Output: {}", output_text);
|
tracing::debug!("Output: {}", output_text);
|
||||||
tracing::info!("Success");
|
tracing::info!("Success");
|
||||||
|
|
||||||
|
let model_id: &str = &info.model_id;
|
||||||
|
let model_id = model_id.to_string();
|
||||||
|
|
||||||
let response = GenerateResponse {
|
let response = GenerateResponse {
|
||||||
generated_text: output_text,
|
generated_text: output_text,
|
||||||
details,
|
details,
|
||||||
|
model_id,
|
||||||
};
|
};
|
||||||
Ok((headers, Json(response)))
|
Ok((headers, Json(response)))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user