Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

feat: return expected types and append suffix

commit fa9aad3ec4
parent cade8dbc2b
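In short: the commit adds an optional `suffix` field to `CompletionRequest`, switches the `completions` handler to destructured axum `Extension` extractors, and replaces the pass-through response with properly typed payloads: a `CompletionCompleteChunk` per SSE event when streaming, and a `Completion` with `Usage` token accounting otherwise. The requested suffix is appended to the generated text on the non-streaming path.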
@@ -306,6 +306,7 @@ pub struct CompletionRequest {
     pub top_p: Option<f32>,
     pub stream: Option<bool>,
     pub seed: Option<u64>,
+    pub suffix: Option<String>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
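Because `CompletionRequest` derives `Deserialize`, the new `Option<String>` field is backward compatible: serde maps a missing `suffix` key to `None`. A minimal sketch of that behavior (the `prompt` field and the `serde_json` call are assumptions for illustration; only `top_p`, `stream`, `seed`, `max_tokens`, and `suffix` are visible in this diff):

    // Hypothetical request bodies; `prompt` is assumed, not shown in this diff.
    let with_suffix = r#"{"prompt": "fn main() {", "max_tokens": 32, "suffix": "\n}"}"#;
    let req: CompletionRequest = serde_json::from_str(with_suffix).unwrap();
    assert_eq!(req.suffix.as_deref(), Some("\n}")); // key present -> Some(...)

    let without = r#"{"prompt": "Hello"}"#;
    let req: CompletionRequest = serde_json::from_str(without).unwrap();
    assert!(req.suffix.is_none()); // key absent -> None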
@@ -4,10 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken,
+    SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -563,8 +563,8 @@ async fn generate_stream_internal(
     )
 )]
 async fn completions(
-    infer: Extension<Infer>,
-    compute_type: Extension<ComputeType>,
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
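The signature change above is plain pattern matching in the parameter list: `Extension(infer): Extension<Infer>` binds the inner `Infer` directly, so the body can hand `infer` to `generate_stream_internal` bare and re-wrap it as `Extension(infer)` only where `generate` still expects the extractor. A toy sketch of the same pattern (the handler and state type are hypothetical):

    use axum::extract::Extension;

    #[derive(Clone)]
    struct AppState {
        model_id: String,
    }

    // Destructuring in the signature yields `state: AppState`,
    // not `state: Extension<AppState>`.
    async fn handler(Extension(state): Extension<AppState>) -> String {
        state.model_id
    }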
@@ -574,6 +574,7 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
+    let suffix = req.suffix.unwrap_or_default();
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
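`unwrap_or_default()` on the `Option<String>` turns a missing suffix into an empty string, so the later `generation.generated_text + &suffix` concatenation is a no-op when no suffix was requested:

    assert_eq!(None::<String>.unwrap_or_default(), ""); // no suffix -> append nothing
    assert_eq!(Some("\n}".to_string()).unwrap_or_default(), "\n}");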
@@ -598,21 +599,103 @@ async fn completions(
         },
     };
 
-    // switch on stream
-    let response = if stream {
-        Ok(
-            generate_stream(infer, compute_type, Json(generate_request.into()))
-                .await
-                .into_response(),
-        )
-    } else {
-        let (headers, Json(generation)) =
-            generate(infer, compute_type, Json(generate_request.into())).await?;
-        // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation])).into_response())
-    };
-
-    response
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let details = generation.details.ok_or((
+            // this should never happen but handle if details are missing unexpectedly
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text + &suffix,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
 }
 
 /// Generate tokens
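Two details worth noting in the hunk above. The `created` timestamp is computed identically in both branches: seconds since the Unix epoch, clamped to 0 if the system clock predates it. A sketch of that pattern as a standalone helper (the function name is hypothetical):

    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    // Mirrors the duplicated pattern in both branches: seconds since the
    // Unix epoch, falling back to 0 when the clock reads before the epoch.
    fn created_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_else(|_| Duration::from_secs(0))
            .as_secs()
    }

And the `Usage` block follows the OpenAI accounting shape: `prompt_tokens` is the prefill length, `completion_tokens` the generated count, and `total_tokens` their sum.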