Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00
feat: return expected types and append suffix
commit fa9aad3ec4
parent cade8dbc2b
@@ -306,6 +306,7 @@ pub struct CompletionRequest {
     pub top_p: Option<f32>,
     pub stream: Option<bool>,
     pub seed: Option<u64>,
+    pub suffix: Option<String>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
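Since the new field is an `Option<String>`, requests that omit `suffix` keep deserializing as before. A minimal sketch of the round-trip, assuming serde and serde_json, with the struct trimmed to just the fields shown above:

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct CompletionRequest {
        top_p: Option<f32>,
        stream: Option<bool>,
        seed: Option<u64>,
        suffix: Option<String>, // the field added by this commit
    }

    fn main() {
        // `suffix` present is carried through; omitted fields stay None.
        let req: CompletionRequest =
            serde_json::from_str(r#"{"suffix": " // end", "stream": true}"#).unwrap();
        assert_eq!(req.suffix.as_deref(), Some(" // end"));
        assert_eq!(req.top_p, None);
    }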
@@ -4,10 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken,
+    SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
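The three new imports, CompletionComplete, CompletionCompleteChunk, and Usage, are the response types the handler below constructs. Their shapes can be inferred from the field accesses in the later hunks; this is a sketch of that inference, not the crate's actual definitions:

    // Inferred from how the handler below builds these values; the real
    // definitions live in the router crate.
    struct CompletionComplete {
        finish_reason: String,
        index: u32,
        logprobs: Option<Vec<f32>>, // assumption: the diff only ever passes None
        text: String,
    }

    struct CompletionCompleteChunk {
        id: String,
        object: String,
        created: u64,
        choices: Vec<CompletionComplete>,
        model: String,
        system_fingerprint: String,
    }

    struct Usage {
        prompt_tokens: u32,
        completion_tokens: u32,
        total_tokens: u32,
    }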
@@ -563,8 +563,8 @@ async fn generate_stream_internal(
     )
 )]
 async fn completions(
-    infer: Extension<Infer>,
-    compute_type: Extension<ComputeType>,
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Extension(info): Extension<Info>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
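The signature change swaps `infer: Extension<Infer>` for axum's destructuring form, so the body gets the inner value directly instead of going through the wrapper. A self-contained sketch of the pattern, with a stand-in state type rather than TGI's:

    use axum::extract::Extension;

    #[derive(Clone)]
    struct Info {
        model_id: String,
    }

    // `Extension(info): Extension<Info>` pattern-matches the wrapper away,
    // so the body uses `info` (an `Info`) instead of `info.0` or `*info`.
    async fn handler(Extension(info): Extension<Info>) -> String {
        info.model_id
    }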
@@ -574,6 +574,7 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
+    let suffix = req.suffix.unwrap_or_default();
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
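The new line follows the same defaulting convention as its neighbors: a missing `suffix` becomes the empty string, so appending it later is a no-op. In isolation, with illustrative values:

    fn main() {
        let max_tokens: Option<u32> = None;
        let suffix: Option<String> = None;

        let max_new_tokens = max_tokens.or(Some(100)); // default token budget
        let suffix = suffix.unwrap_or_default();       // None -> ""

        assert_eq!(max_new_tokens, Some(100));
        assert!(suffix.is_empty());
    }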
@@ -598,21 +599,103 @@ async fn completions(
         },
     };
 
-    // switch on stream
-    let response = if stream {
-        Ok(
-            generate_stream(infer, compute_type, Json(generate_request.into()))
-                .await
-                .into_response(),
-        )
-    } else {
-        let (headers, Json(generation)) =
-            generate(infer, compute_type, Json(generate_request.into())).await?;
-        // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation])).into_response())
-    };
-
-    response
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let details = generation.details.ok_or((
+            // this should never happen but handle if details are missing unexpectedly
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text + &suffix,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
 }
 
 /// Generate tokens
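Two details of the non-streaming branch are worth pulling out: the response text is `generation.generated_text + &suffix`, so the requested suffix lands after the completion, and `Usage` is derived from the prefill length plus the generated-token count. A standalone sketch with stand-in values, not TGI's types:

    fn main() {
        // Stand-ins for `generation.generated_text`, the request suffix,
        // and the detail counts used above; the numbers are illustrative.
        let generated_text = String::from("a + b\n}");
        let suffix = String::from("\n// end of snippet");
        let prefill_len: u32 = 12;     // prompt tokens (details.prefill.len())
        let generated_tokens: u32 = 7; // completion tokens

        let text = generated_text + &suffix; // suffix is appended, never prepended
        let total_tokens = prefill_len + generated_tokens;

        println!("{text}");
        assert_eq!(total_tokens, 19);
    }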