use crate::ffi::{
    create_worker_frontend, GenerationParams, LlamaCppWorkerFrontend, SamplingParams,
};
use async_trait::async_trait;
use cxx::UniquePtr;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;
use std::thread::{spawn, JoinHandle};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{
    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
};
use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info};

type InferResult = Result<InferStreamResponse, InferError>;
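
// SAFETY: the worker frontend is assumed to be moved into and driven by a single
// scheduler thread only (see `scheduler_loop`); it is never shared between threads.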
unsafe impl Send for LlamaCppWorkerFrontend {}

impl From<&ValidParameters> for SamplingParams {
    fn from(v: &ValidParameters) -> Self {
        Self {
            top_k: v.top_k,
            top_p: v.top_p,
            frequency_penalty: v.frequency_penalty,
            repetition_penalty: v.repetition_penalty,
            seed: v.seed,
        }
    }
}

impl From<&ValidStoppingParameters> for GenerationParams {
    fn from(v: &ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: v.max_new_tokens,
            ignore_eos_token: v.ignore_eos_token,
        }
    }
}
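
/// Per-request generation state handed to the llama.cpp worker: the prompt
/// tokens, the tokens generated so far and the generation/sampling parameters.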
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct GenerationContext {
    pub(crate) input_tokens: Arc<Vec<u32>>,
    pub(crate) generated_tokens: Vec<u32>,
    pub(crate) generation_params: GenerationParams,
    pub(crate) sampling_params: SamplingParams,
}
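
/// Everything the generation callback needs to turn raw tokens into
/// `InferStreamResponse`s: the request start time, the response stream,
/// a tokenizer handle and the mutable generation state.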
pub(crate) struct InferContext {
    pub(crate) start: Instant,
    pub(crate) stream: UnboundedSender<InferResult>,
    pub(crate) tokenizer: Tokenizer,
    pub(crate) generation: GenerationContext,
}

#[derive(Debug, Error)]
pub enum LlamaCppBackendError {
    #[error("Provided GGUF model path {0} doesn't exist")]
    ModelFileDoesntExist(String),

    #[error("Failed to initialize model from GGUF file {0}: {1}")]
    ModelInitializationFailed(PathBuf, String),
}

// pub struct LlamaCppBackend {
//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
//     _scheduler_handle: JoinHandle<()>,
// }
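
/// Handle over a single llama.cpp worker: the channel used to submit requests
/// and the `JoinHandle` of the thread running `scheduler_loop`.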
struct LlamaCppWorker {
    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
    handle: JoinHandle<()>,
}
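
/// llama.cpp backend. Only a single-worker variant exists today; a multi-worker
/// variant is sketched below but not implemented yet.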
pub enum LlamaCppBackend {
    Single(LlamaCppWorker),
    // Multi(Vec<LlamaCppWorker>)
}

impl LlamaCppBackend {
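    /// Instantiates the native llama.cpp worker frontend for the GGUF file at
    /// `path`, mapping any error from the C++ side into `ModelInitializationFailed`.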
    fn allocate_worker(
        path: &Path,
    ) -> Result<UniquePtr<LlamaCppWorkerFrontend>, LlamaCppBackendError> {
        create_worker_frontend(&path.display().to_string()).map_err(|ref err| {
            LlamaCppBackendError::ModelInitializationFailed(path.to_path_buf(), err.to_string())
        })
    }
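
    /// Builds the backend from a GGUF model file and a `Tokenizer`.
    /// Only `num_cores_per_instance == 0` (a single worker) is supported for now;
    /// any other value currently panics.
    ///
    /// Minimal usage sketch (file names are illustrative only, error handling elided):
    /// ```ignore
    /// let tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();
    /// let backend = LlamaCppBackend::new("model.gguf", tokenizer, 0)?;
    /// ```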
    pub fn new<P: AsRef<Path>>(
        model_path: P,
        tokenizer: Tokenizer,
        num_cores_per_instance: u16,
    ) -> Result<Self, LlamaCppBackendError> {
        let shared_path = Arc::new(model_path);
        let path = shared_path.deref().as_ref();
        if !path.exists() {
            return Err(LlamaCppBackendError::ModelFileDoesntExist(
                path.display().to_string(),
            ));
        }

        let worker = match num_cores_per_instance {
            0 => {
                let worker = Self::allocate_worker(path)?;
                let (sender, receiver) = channel();
                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
                LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
            }
            _ => panic!("Not supported yet"),
        };

        Ok(worker)
    }
}
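
/// Callback invoked from the native worker for every generated token.
///
/// It decodes the token, pushes an `Intermediate` (or final `End`) response on
/// the stream, and returns `true` when the response could not be delivered
/// (i.e. the receiving side is gone), which the native side can use to stop
/// generating early.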
fn llama_generate_callback(
    ctx: *mut InferContext,
    new_token_id: u32,
    new_token_logit: f32,
    is_final: bool,
    n_generated_tokens: usize,
) -> bool {
    debug!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");

    // SAFETY: `ctx` points to the `InferContext` leaked via `Box::into_raw` in
    // `scheduler_loop`; it stays alive until the native stream call returns.
    let ctx = unsafe { &mut *ctx };

    // Append the new token to the generated ones
    ctx.generation.generated_tokens.push(new_token_id);

    // Generate response
    let response = match ctx.tokenizer.decode(&[new_token_id], false) {
        Ok(text) => {
            let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
            let token = Token {
                id: new_token_id,
                text,
                logprob: new_token_logit,
                special,
            };

            // Should we generate an ending or intermediate response?
            match is_final {
                false => Ok(InferStreamResponse::Intermediate {
                    token,
                    top_tokens: vec![],
                }),
                true => {
                    // Decode the whole text
                    match ctx
                        .tokenizer
                        .decode(&ctx.generation.generated_tokens, false)
                    {
                        Ok(text) => Ok(InferStreamResponse::End {
                            token,
                            top_tokens: vec![],
                            generated_text: GeneratedText {
                                text,
                                generated_tokens: n_generated_tokens as u32,
                                finish_reason: FinishReason::Length,
                                seed: Some(ctx.generation.sampling_params.seed),
                            },
                            start: ctx.start,
                            queued: ctx.start,
                        }),
                        Err(err) => Err(InferError::GenerationError(err.to_string())),
                    }
                }
            }
        }
        Err(ref err) => Err(InferError::GenerationError(err.to_string())),
    };

    // Send back to the client
    let status = ctx.stream.send(response).inspect_err(|err| {
        error!("Failed to send back the response: {}", err);
    });
    status.is_err()
}
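
/// Single-worker scheduling loop: pulls requests from the backlog and drives the
/// native llama.cpp worker synchronously, one request at a time, until the
/// submission channel is closed.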
fn scheduler_loop(
    mut backend: UniquePtr<LlamaCppWorkerFrontend>,
    tokenizer: Tokenizer,
    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
    // This loop mostly decodes a single token at each step, so there is no need
    // for tokenizer parallelism
    tokenizers::utils::parallelism::set_parallelism(false);

    loop {
        if let Ok((generation, stream)) = backlog.recv() {
            let start = Instant::now();
            let tokenizer = tokenizer.clone();
            let generation_params = generation.generation_params; // copy
            let sampling_params = generation.sampling_params; // copy
            let input_tokens = Arc::clone(&generation.input_tokens);

            // Create the whole InferContext and move it to the heap
            {
                let ctx = Box::new(InferContext {
                    start,
                    stream,
                    tokenizer,
                    generation,
                });

                // We leak the box so it is not freed after the first callback call,
                // when it would otherwise go out of scope
                unsafe {
                    let boxed_ctx = Box::into_raw(ctx);
                    if let Err(e) = backend.pin_mut().stream(
                        &input_tokens,
                        generation_params,
                        &sampling_params,
                        boxed_ctx,
                        llama_generate_callback,
                    ) {
                        error!("Error while decoding tokens... {}", e.what());
                    }

                    // Take the InferContext box back so it is properly dropped
                    // once streaming for this request is done
                    let _ = Box::from_raw(boxed_ctx);
                }
            }
        } else {
            info!("IPC channel is closed, exiting the scheduler loop");
            break;
        }
    }
}

#[async_trait]
impl Backend for LlamaCppBackend {
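    /// Validates the request, builds a `GenerationContext` and hands it to the
    /// worker thread; tokens are streamed back through the returned receiver as
    /// `llama_generate_callback` produces them.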
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<InferResult>, InferError> {
        if let Some(input_ids) = request.input_ids {
            let (sx, rx) = unbounded_channel();
            let sampling_params = SamplingParams::from(&request.parameters);
            let generation_params = GenerationParams::from(&request.stopping_parameters);

            let ctx = GenerationContext {
                input_tokens: Arc::clone(&input_ids),
                generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
                generation_params,
                sampling_params,
            };

            match self {
                LlamaCppBackend::Single(worker) => match worker.sender.send((ctx, sx)) {
                    Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
                    Err(_) => Err(InferError::GenerationError(
                        "Failed to send the request".to_string(),
                    )),
                },
            }
        } else {
            Err(InferError::GenerationError(
                "Unsupported modalities".to_string(),
            ))
        }
    }
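
    /// Liveness probe; the native worker currently exposes no health signal, so
    /// this always reports healthy.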
    async fn health(&self, _: bool) -> bool {
        true
    }
}