feat(backend): bind incoming request to the server

Morgan Funtowicz 2024-11-01 00:50:42 +01:00
parent d4aee42fd8
commit 612f2f939f
2 changed files with 129 additions and 31 deletions

View File

@@ -2,18 +2,54 @@ use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
 use async_trait::async_trait;
-use cxx::{Exception, UniquePtr};
+use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
+use std::sync::mpsc::{channel, Receiver, SendError, Sender};
 use std::sync::Arc;
-use std::thread::spawn;
+use std::thread::{spawn, JoinHandle};
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{
+    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
+};
+use text_generation_router::Token;
 use thiserror::Error;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::TryAcquireError;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::info;
+use tracing::{error, info};
 
 unsafe impl Send for LlamaCppBackendImpl {}
 
+impl From<&ValidParameters> for SamplingParams {
+    fn from(v: &ValidParameters) -> Self {
+        Self {
+            top_k: v.top_k,
+            top_p: v.top_p,
+            frequency_penalty: v.frequency_penalty,
+            repetition_penalty: v.repetition_penalty,
+            seed: v.seed,
+        }
+    }
+}
+
+impl From<&ValidStoppingParameters> for GenerationParams {
+    fn from(v: &ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: v.max_new_tokens,
+            ignore_eos_token: v.ignore_eos_token,
+        }
+    }
+}
+
+#[cfg_attr(debug_assertions, derive(Debug))]
+struct InferContext {
+    pub(crate) stream: UnboundedSender<Result<InferStreamResponse, InferError>>,
+    pub(crate) input_tokens: Arc<Vec<u32>>,
+    pub(crate) generated_tokens: Vec<u32>,
+    pub(crate) generation_params: GenerationParams,
+    pub(crate) sampling_params: SamplingParams,
+}
+
 #[derive(Debug, Error)]
 pub enum LlamaCppBackendError {
     #[error("Provided GGUF model path {0} doesn't exist")]
@@ -23,7 +59,10 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-pub struct LlamaCppBackend {}
+pub struct LlamaCppBackend {
+    backlog: Sender<InferContext>,
+    scheduler_handle: JoinHandle<()>,
+}
 
 impl LlamaCppBackend {
     pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
@@ -34,7 +73,7 @@ impl LlamaCppBackend {
             ));
         }
 
-        let mut backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
+        let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
             LlamaCppBackendError::ModelInitializationFailed(
                 path.to_path_buf(),
                 err.what().to_string(),
@@ -46,43 +85,100 @@ impl LlamaCppBackend {
             path.display()
         );
 
-        let j = spawn(|| scheduler_loop(backend));
-        j.join().ok();
-        Ok(Self {})
+        let (submitter, receiver) = channel();
+        let handle = spawn(|| scheduler_loop(backend, receiver));
+        Ok(Self {
+            backlog: submitter,
+            scheduler_handle: handle,
+        })
     }
 }
 
-fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
-    println!("Scheduler loop");
-    let tokens = [128000u32, 5159, 836, 374, 23809];
-    let mut generated = vec![0u32; 16];
-    let generation_params = GenerationParams {
-        max_new_tokens: generated.len() as u32,
-    };
-    let sampling_params = SamplingParams::default();
-
-    match backend.pin_mut().generate(
-        &tokens,
-        &mut generated,
-        &generation_params,
-        &sampling_params,
-        |new_token_id: u32, is_eos: bool| println!("Generated {new_token_id} (is_eos: {is_eos})"),
-    ) {
-        Ok(n_tokens) => {
-            generated.truncate(n_tokens);
-            println!("Generated {} tokens -> {:?}", n_tokens, generated);
-        }
-        Err(err) => println!("Error: {}", err),
-    }
-}
+fn scheduler_loop(
+    mut backend: UniquePtr<LlamaCppBackendImpl>,
+    mut backlog: Receiver<InferContext>,
+) {
+    loop {
+        println!("Looping");
+        if let Ok(mut ctx) = backlog.recv() {
+            println!("{ctx:?}, {}", &ctx.generated_tokens.capacity());
+            match backend.pin_mut().generate(
+                &ctx.input_tokens,
+                &mut ctx.generated_tokens,
+                &ctx.generation_params,
+                &ctx.sampling_params,
+                |new_token_id: u32, new_token_logit: f32, is_eos: bool| {
+                    let response = InferStreamResponse::Intermediate {
+                        token: Token {
+                            id: new_token_id,
+                            text: "".to_string(),
+                            logprob: new_token_logit,
+                            special: false,
+                        },
+                        top_tokens: vec![],
+                    };
+                    println!("Generated token: {response:?}");
+                    // let _ = tokio::spawn(async {
+                    //     match ctx.stream.send(Ok(response)) {
+                    //         Ok(_) => {}
+                    //         Err(ref err) => {
+                    //             error!(
+                    //                 "Failed to send back token to the client: {}",
+                    //                 err.to_string()
+                    //             );
+                    //         }
+                    //     }
+                    // });
+                },
+            ) {
+                Ok(n_tokens) => {
+                    unsafe {
+                        ctx.generated_tokens.set_len(n_tokens);
+                    }
+                    println!(
+                        "Generated {} tokens -> {:?}",
+                        n_tokens, &ctx.generated_tokens
+                    );
+                }
+                Err(err) => println!("Error: {}", err),
+            }
+        } else {
+            info!("IPC channel is closed, exiting the scheduler loop");
+            break;
+        }
+    }
+}
 
 #[async_trait]
 impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        Err(InferError::GenerationError("Not implemented yet".into()))
+        if let Some(input_ids) = request.input_ids {
+            let (sx, rx) = unbounded_channel();
+            let sampling_params = SamplingParams::from(&request.parameters);
+            let generation_params = GenerationParams::from(&request.stopping_parameters);
+            let ctx = InferContext {
+                stream: sx,
+                input_tokens: Arc::clone(&input_ids),
+                generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
+                generation_params,
+                sampling_params,
+            };
+            match self.backlog.send(ctx) {
+                Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
+                Err(_) => Err(InferError::GenerationError(
+                    "Failed to send the request".to_string(),
+                )),
+            }
+        } else {
+            Err(InferError::GenerationError(
+                "Unsupported modalities".to_string(),
+            ))
+        }
     }
 
     async fn health(&self, _: bool) -> bool {
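The pattern this file introduces -- an async schedule() that packages the request into an InferContext, pushes it over a std::sync::mpsc channel to a dedicated scheduler thread, and hands the caller back a tokio UnboundedReceiverStream -- can be reduced to a small standalone sketch. Everything below (Request, Context, Backend, the fabricated token loop) is a simplified stand-in rather than the actual TGI types, a real backend would drive llama.cpp through the FFI layer instead of inventing token ids, and the sketch assumes tokio (rt, macros, sync features) and tokio-stream as dependencies:

use std::sync::mpsc::{channel, Sender};
use std::thread::spawn;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;

// Simplified stand-in for ValidGenerateRequest.
struct Request {
    input_tokens: Vec<u32>,
    max_new_tokens: u32,
}

// Simplified stand-in for InferContext: one generated token id per message.
struct Context {
    stream: UnboundedSender<u32>,
    request: Request,
}

struct Backend {
    backlog: Sender<Context>,
}

impl Backend {
    fn new() -> Self {
        let (submitter, receiver) = channel::<Context>();
        // Dedicated scheduler thread: drains the backlog and "generates" tokens,
        // pushing each one back through the per-request tokio channel.
        spawn(move || {
            while let Ok(ctx) = receiver.recv() {
                for i in 0..ctx.request.max_new_tokens {
                    // A real backend would call into llama.cpp here.
                    let fabricated = ctx.request.input_tokens.len() as u32 + i;
                    if ctx.stream.send(fabricated).is_err() {
                        break; // client dropped the stream
                    }
                }
            }
        });
        Self { backlog: submitter }
    }

    // Mirrors Backend::schedule: enqueue the request, hand the caller a stream.
    fn schedule(&self, request: Request) -> Option<UnboundedReceiverStream<u32>> {
        let (sx, rx) = unbounded_channel();
        self.backlog
            .send(Context { stream: sx, request })
            .ok()
            .map(|_| UnboundedReceiverStream::new(rx))
    }
}

#[tokio::main]
async fn main() {
    use tokio_stream::StreamExt;

    let backend = Backend::new();
    let mut stream = backend
        .schedule(Request { input_tokens: vec![1, 2, 3], max_new_tokens: 4 })
        .expect("scheduler thread is gone");

    while let Some(token_id) = stream.next().await {
        println!("got token {token_id}");
    }
}

Keeping a blocking std channel on the scheduler side keeps the llama.cpp calls off the tokio runtime, while the per-request unbounded tokio channel lets the async handler stream tokens back to the client without blocking.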

View File

@@ -16,11 +16,13 @@ impl Default for SamplingParams {
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
+    #[derive(Debug, Copy, Clone)]
     struct GenerationParams {
         max_new_tokens: u32,
         ignore_eos_token: bool,
     }
 
+    #[derive(Debug, Copy, Clone)]
     struct SamplingParams {
         top_k: u32,
         top_p: f32,
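The new #[derive(Debug, Copy, Clone)] on the cxx-bridge structs is what lets them be embedded by value in InferContext and printed from the scheduler loop in debug builds, and the From impls added in backend.rs keep the router-side validation types decoupled from the FFI types. A minimal sketch of that conversion follows; the structs are stand-ins, the field names come from the From impl above, but their exact types are assumptions rather than something shown in this diff:

// Stand-in for the cxx-bridge SamplingParams; field types are assumed.
#[derive(Debug, Copy, Clone)]
struct SamplingParams {
    top_k: u32,
    top_p: f32,
    frequency_penalty: f32,
    repetition_penalty: f32,
    seed: u64,
}

// Stand-in for text_generation_router::validation::ValidParameters.
struct ValidParameters {
    top_k: u32,
    top_p: f32,
    frequency_penalty: f32,
    repetition_penalty: f32,
    seed: u64,
}

// Same conversion shape as the From impl added in backend.rs.
impl From<&ValidParameters> for SamplingParams {
    fn from(v: &ValidParameters) -> Self {
        Self {
            top_k: v.top_k,
            top_p: v.top_p,
            frequency_penalty: v.frequency_penalty,
            repetition_penalty: v.repetition_penalty,
            seed: v.seed,
        }
    }
}

fn main() {
    let valid = ValidParameters {
        top_k: 40,
        top_p: 0.95,
        frequency_penalty: 0.0,
        repetition_penalty: 1.1,
        seed: 42,
    };

    // Build the FFI-side params without holding onto the validation type.
    let sampling = SamplingParams::from(&valid);

    // Copy semantics: assigning by value does not move `sampling` out.
    let snapshot = sampling;
    println!("{sampling:?} == {snapshot:?}");
}

Deriving Copy is cheap here because both bridge structs are small plain-old-data types, so the InferContext can own its own copies and cross the scheduler thread boundary without reference lifetimes.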