diff --git a/router/src/lib.rs b/router/src/lib.rs
index 44eb6010..494212c5 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -306,6 +306,7 @@ pub struct CompletionRequest {
     pub top_p: Option<f32>,
     pub stream: Option<bool>,
     pub seed: Option<u64>,
+    pub suffix: Option<String>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 89f9d740..f7ab4160 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -4,10 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken,
+    SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -563,8 +563,8 @@ async fn generate_stream_internal(
     )
 )]
 async fn completions(
-    infer: Extension<Infer>,
-    compute_type: Extension<ComputeType>,
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -574,6 +574,7 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
+    let suffix = req.suffix.unwrap_or_default();
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
@@ -598,21 +599,103 @@ async fn completions(
         },
     };
 
-    // switch on stream
-    let response = if stream {
-        Ok(
-            generate_stream(infer, compute_type, Json(generate_request.into()))
-                .await
-                .into_response(),
-        )
-    } else {
-        let (headers, Json(generation)) =
-            generate(infer, compute_type, Json(generate_request.into())).await?;
-        // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation])).into_response())
-    };
-
-    response
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        // this should never happen, but handle it if details are unexpectedly missing
+        let details = generation.details.ok_or((
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text + &suffix,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
 }
 
 /// Generate tokens
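
For reference, a minimal client-side sketch of the new endpoint (not part of the patch). It assumes the handler above is mounted at /v1/completions on localhost:3000, and that CompletionRequest also carries prompt and max_tokens fields; none of these appear in this diff. Per the non-streaming branch, the returned text is generated_text with the request's suffix appended, and usage reports prefill and generated token counts.

    // Hypothetical caller using reqwest, tokio, and serde_json (assumed
    // dependencies); route path and request fields are assumptions.
    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        let body = json!({
            "prompt": "def fibonacci(n):",      // assumed field, not shown in this diff
            "max_tokens": 32,                   // mapped to max_new_tokens by the handler
            "suffix": "\nprint(fibonacci(10))", // appended to generated_text in the reply
            "stream": false,
            "seed": 42
        });

        let response = reqwest::Client::new()
            .post("http://localhost:3000/v1/completions") // assumed mount point
            .json(&body)
            .send()
            .await?;

        // Expect a Completion payload: object = "text_completion", plus choices and usage.
        println!("{}", response.text().await?);
        Ok(())
    }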