Included model_id in response.

The /generate API call now returns the name of the model that generated the response.
This commit is contained in:
Suvro Ghosh 2023-10-30 19:23:19 -04:00
parent 96a982ad8f
commit 897ed20842
2 changed files with 10 additions and 1 deletions

View File

@ -265,6 +265,9 @@ pub(crate) struct GenerateResponse {
pub generated_text: String, pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>, pub details: Option<Details>,
// Model that generated the text
#[schema(example = "bigscience/blomm-560m")]
pub model_id: String,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]

View File

@ -56,6 +56,7 @@ example = json ! ({"error": "Incomplete generation"})),
async fn compat_generate( async fn compat_generate(
Extension(default_return_full_text): Extension<bool>, Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer>,
info: Extension<Info>,
Json(mut req): Json<CompatGenerateRequest>, Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag // default return_full_text given the pipeline_tag
@ -69,7 +70,7 @@ async fn compat_generate(
.await .await
.into_response()) .into_response())
} else { } else {
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?; let (headers, Json(generation)) = generate(infer, info, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response()) Ok((headers, Json(vec![generation])).into_response())
} }
@ -144,6 +145,7 @@ seed,
)] )]
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
info: Extension<Info>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
@ -288,9 +290,13 @@ async fn generate(
tracing::debug!("Output: {}", output_text); tracing::debug!("Output: {}", output_text);
tracing::info!("Success"); tracing::info!("Success");
let model_id: &str = &info.model_id;
let model_id = model_id.to_string();
let response = GenerateResponse { let response = GenerateResponse {
generated_text: output_text, generated_text: output_text,
details, details,
model_id,
}; };
Ok((headers, Json(response))) Ok((headers, Json(response)))
} }