feat: return expected types and append suffix

drbh 2024-02-06 10:31:43 -05:00
parent cade8dbc2b
commit fa9aad3ec4
2 changed files with 104 additions and 20 deletions
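The headline behavior of this commit: the completions handler now appends an optional suffix to the generated text before returning it. As a minimal sketch (not code from this commit, using a hypothetical helper name), the append-with-default logic amounts to:

// Hypothetical helper mirroring the handler change below: a missing
// suffix defaults to "" via unwrap_or_default(), leaving output unchanged.
fn with_suffix(generated_text: String, suffix: Option<String>) -> String {
    generated_text + &suffix.unwrap_or_default()
}

fn main() {
    assert_eq!(
        with_suffix("fn add() {".to_string(), Some("\n}".to_string())),
        "fn add() {\n}"
    );
    assert_eq!(with_suffix("hello".to_string(), None), "hello");
}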

View File

@@ -306,6 +306,7 @@ pub struct CompletionRequest {
     pub top_p: Option<f32>,
     pub stream: Option<bool>,
     pub seed: Option<u64>,
+    pub suffix: Option<String>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]

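Because the new field is an Option<String>, older request bodies that omit "suffix" still deserialize unchanged. A minimal sketch with a stripped-down, hypothetical stand-in struct (only the fields visible in the hunk above; assumes serde and serde_json as dependencies):

use serde::Deserialize;

// Stand-in for the real CompletionRequest, reduced to the fields shown above.
#[derive(Deserialize)]
struct RequestSketch {
    top_p: Option<f32>,
    stream: Option<bool>,
    seed: Option<u64>,
    suffix: Option<String>,
}

fn main() {
    // A pre-existing request without "suffix" still parses; the field is None.
    let old: RequestSketch = serde_json::from_str(r#"{"seed": 42}"#).unwrap();
    assert!(old.suffix.is_none());

    // A new-style request carries the text to append after the completion.
    let new: RequestSketch = serde_json::from_str(r#"{"suffix": "\n}"}"#).unwrap();
    assert_eq!(new.suffix.as_deref(), Some("\n}"));
}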
View File

@@ -4,10 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken,
+    SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -563,8 +563,8 @@ async fn generate_stream_internal(
     )
 )]
 async fn completions(
-    infer: Extension<Infer>,
-    compute_type: Extension<ComputeType>,
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -574,6 +574,7 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
+    let suffix = req.suffix.unwrap_or_default();
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
@@ -598,21 +599,103 @@ async fn completions(
         },
     };
     // switch on stream
-    let response = if stream {
-        Ok(
-            generate_stream(infer, compute_type, Json(generate_request.into()))
-                .await
-                .into_response(),
-        )
-    } else {
-        let (headers, Json(generation)) =
-            generate(infer, compute_type, Json(generate_request.into())).await?;
-
-        // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation])).into_response())
-    };
-
-    response
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let details = generation.details.ok_or((
+            // this should never happen but handle if details are missing unexpectedly
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text + &suffix,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
 }
 
 /// Generate tokens
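For reference, the non-streaming branch's usage accounting and fingerprint are plain arithmetic and string formatting. A sketch using hypothetical free functions (not the commit's code), under the assumption that the prompt token count equals the prefill length:

// Mirrors `usage: Usage { ... }` above: prompt tokens come from the prefill
// length and the total is the sum of prompt and generated tokens.
fn usage_counts(prefill_len: usize, generated_tokens: u32) -> (u32, u32, u32) {
    let prompt = prefill_len as u32;
    (prompt, generated_tokens, prompt + generated_tokens)
}

// Mirrors the `system_fingerprint` format string used in both branches.
fn system_fingerprint(version: &str, docker_label: Option<&str>) -> String {
    format!("{}-{}", version, docker_label.unwrap_or("native"))
}

fn main() {
    assert_eq!(usage_counts(12, 100), (12, 100, 112));
    // Example values only: any version string works the same way.
    assert_eq!(system_fingerprint("1.4.0", None), "1.4.0-native");
}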