use crate::ffi::{
    create_worker_frontend, GenerationParams, LlamaCppWorkerFrontend, SamplingParams,
};
use async_trait::async_trait;
use cxx::UniquePtr;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;
use std::thread::{spawn, JoinHandle};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{
    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
};
use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info};

/// Result type flowing through the streaming channel back to the router.
type InferResult = Result<InferStreamResponse, InferError>;

// SAFETY: the C++ worker frontend is only ever driven from the single scheduler
// thread that owns it, so it is safe to move it across the thread boundary at spawn time.
unsafe impl Send for LlamaCppWorkerFrontend {}

impl From<&ValidParameters> for SamplingParams {
    fn from(v: &ValidParameters) -> Self {
        Self {
            top_k: v.top_k,
            top_p: v.top_p,
            frequency_penalty: v.frequency_penalty,
            repetition_penalty: v.repetition_penalty,
            seed: v.seed,
        }
    }
}

impl From<&ValidStoppingParameters> for GenerationParams {
    fn from(v: &ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: v.max_new_tokens,
            ignore_eos_token: v.ignore_eos_token,
        }
    }
}

#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct GenerationContext {
    pub(crate) input_tokens: Arc<Vec<u32>>,
    pub(crate) generated_tokens: Vec<u32>,
    pub(crate) generation_params: GenerationParams,
    pub(crate) sampling_params: SamplingParams,
}

pub(crate) struct InferContext<'a> {
    pub(crate) start: Instant,
    pub(crate) stream: UnboundedSender<InferResult>,
    pub(crate) tokenizer: &'a Tokenizer,
    pub(crate) generation: GenerationContext,
}

#[derive(Debug, Error)]
pub enum LlamaCppBackendError {
    #[error("Provided GGUF model path {0} doesn't exist")]
    ModelFileDoesntExist(String),

    #[error("Failed to initialize model from GGUF file {0}: {1}")]
    ModelInitializationFailed(PathBuf, String),
}

struct LlamaCppWorker {
    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
    handle: JoinHandle<()>,
}

pub enum LlamaCppBackend {
    Single(LlamaCppWorker),
    // Multi(Vec<LlamaCppWorker>)
}

impl LlamaCppBackend {
    fn allocate_worker(
        path: &Path,
    ) -> Result<UniquePtr<LlamaCppWorkerFrontend>, LlamaCppBackendError> {
        create_worker_frontend(&path.display().to_string()).map_err(|ref err| {
            LlamaCppBackendError::ModelInitializationFailed(path.to_path_buf(), err.to_string())
        })
    }

    pub fn new<P: AsRef<Path>>(
        model_path: P,
        tokenizer: Arc<Tokenizer>,
        num_cores_per_instance: u16,
    ) -> Result<Self, LlamaCppBackendError> {
        let shared_path = Arc::new(model_path);
        let path = shared_path.deref().as_ref();
        if !path.exists() {
            return Err(LlamaCppBackendError::ModelFileDoesntExist(
                path.display().to_string(),
            ));
        }

        let worker = match num_cores_per_instance {
            0 => {
                let worker = Self::allocate_worker(path)?;
                let (sender, receiver) = channel();
                let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
                LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
            }
            _ => panic!("Not supported yet"),
        };
        Ok(worker)
    }
}

fn llama_generate_callback(
    ctx: *mut InferContext,
    new_token_id: u32,
    new_token_logit: f32,
    is_final: bool,
    n_generated_tokens: usize,
) -> bool {
    debug!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");

    let ctx = unsafe { &mut *ctx };

    // Append the new token to the generated ones
    ctx.generation.generated_tokens.push(new_token_id);

    // Generate response
    let response = match ctx.tokenizer.decode(&[new_token_id], false) {
        Ok(text) => {
            let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
            let token = Token {
                id: new_token_id,
                text,
                logprob: new_token_logit,
                special,
            };

            // Should we generate an ending or intermediate response?
            match is_final {
                false => Ok(InferStreamResponse::Intermediate {
                    token,
                    top_tokens: vec![],
                }),
                true => {
                    // Decode the whole text
                    match ctx
                        .tokenizer
                        .decode(&ctx.generation.generated_tokens, false)
                    {
                        Ok(text) => Ok(InferStreamResponse::End {
                            token,
                            top_tokens: vec![],
                            generated_text: GeneratedText {
                                text,
                                generated_tokens: n_generated_tokens as u32,
                                finish_reason: FinishReason::Length,
                                seed: Some(ctx.generation.sampling_params.seed),
                            },
                            start: ctx.start,
                            queued: ctx.start,
                        }),
                        Err(err) => Err(InferError::GenerationError(err.to_string())),
                    }
                }
            }
        }
        Err(ref err) => Err(InferError::GenerationError(err.to_string())),
    };

    // Send back to the client
    let status = ctx.stream.send(response).inspect_err(|err| {
        error!("Failed to send back the response: {}", err);
    });

    // Returning `true` tells the worker to stop generating for this request
    // (the receiver side is gone, so there is no point in producing more tokens).
    status.is_err()
}

fn scheduler_loop(
    mut backend: UniquePtr<LlamaCppWorkerFrontend>,
    tokenizer: Arc<Tokenizer>,
    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
    // This loop will mostly decode a single token at every step, so there is no need
    // to rely on the tokenizer's internal parallelism.
    tokenizers::utils::parallelism::set_parallelism(false);

    loop {
        if let Ok((generation, stream)) = backlog.recv() {
            let start = Instant::now();
            let generation_params = generation.generation_params; // copy
            let sampling_params = generation.sampling_params; // copy
            let input_tokens = Arc::clone(&generation.input_tokens);

            // Creating the whole InferContext and pushing it to the heap
            let ctx = Box::new(InferContext {
                start,
                stream,
                tokenizer: &tokenizer,
                generation,
            });

            // We leak the box to avoid it being freed after the first callback call
            // when going out of scope
            unsafe {
                let boxed_ctx = Box::into_raw(ctx);
                if let Err(e) = backend.pin_mut().stream(
                    &input_tokens,
                    generation_params,
                    &sampling_params,
                    boxed_ctx,
                    llama_generate_callback,
                ) {
                    error!("Error while decoding tokens... {}", e.what());
                }

                // Reclaim ownership of the InferContext box so it is dropped here
                let _ = Box::from_raw(boxed_ctx);
            }
        } else {
            info!("IPC channel is closed, exiting the scheduler loop");
            break;
        }
    }
}

#[async_trait]
impl Backend for LlamaCppBackend {
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        if let Some(input_ids) = request.input_ids {
            let (sx, rx) = unbounded_channel();
            let sampling_params = SamplingParams::from(&request.parameters);
            let generation_params = GenerationParams::from(&request.stopping_parameters);

            let ctx = GenerationContext {
                input_tokens: Arc::clone(&input_ids),
                generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
                generation_params,
                sampling_params,
            };

            match self {
                LlamaCppBackend::Single(worker) => match worker.sender.send((ctx, sx)) {
                    Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
                    Err(_) => Err(InferError::GenerationError(
                        "Failed to send the request".to_string(),
                    )),
                },
            }
        } else {
            Err(InferError::GenerationError(
                "Unsupported modalities".to_string(),
            ))
        }
    }

    async fn health(&self, _: bool) -> bool {
        true
    }
}
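
// Usage sketch (illustrative only, not part of this module): shows how a caller might
// construct the backend and hand it to the router. The "tokenizer.json" and "model.gguf"
// paths are hypothetical placeholders and error handling is elided.
//
//     use std::sync::Arc;
//     use tokenizers::Tokenizer;
//
//     let tokenizer = Arc::new(Tokenizer::from_file("tokenizer.json").expect("tokenizer"));
//     let backend = LlamaCppBackend::new("model.gguf", tokenizer, 0)?;
//     // `backend` implements `text_generation_router::infer::Backend`: the router calls
//     // `schedule()` for each validated request and consumes the returned
//     // `UnboundedReceiverStream` of `InferStreamResponse` items.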