diff --git a/Dockerfile_llamacpp b/Dockerfile_llamacpp
index 7f083ec2..006a3204 100644
--- a/Dockerfile_llamacpp
+++ b/Dockerfile_llamacpp
@@ -18,7 +18,7 @@ RUN apt-get install -y \
 # nvidia-cuda-toolkit
 # -DGGML_CUDA=ON \
 
-ENV LLAMA_VERSION=b4585
+ENV LLAMA_VERSION=b4599
 RUN git clone --depth 1 -b ${LLAMA_VERSION} https://github.com/ggerganov/llama.cpp \
     && cd llama.cpp \
     && cmake -B build \
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 85976779..0d3c3950 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -25,9 +25,9 @@ pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
-    pub max_batch_size: Option<usize>,
+    pub max_batch_size: usize,
     pub batch_timeout: Duration,
-    pub n_threads: isize,
+    pub n_threads: usize,
     pub use_mmap: bool,
     pub use_mlock: bool,
     pub flash_attention: bool,
@@ -88,6 +88,7 @@ struct Llamacpp {
     model: *mut bindings::llama_model,
     ctx: *mut bindings::llama_context,
     vocab: *const bindings::llama_vocab,
+    logprobs: Vec<bindings::llama_token_data>,
     batch: bindings::llama_batch,
     n_ctx: u32,
 }
@@ -115,7 +116,7 @@ impl Llamacpp {
 
         let model = unsafe {
             let mut params = bindings::llama_model_default_params();
-            params.use_mmap = conf.use_mmap;
+            params.use_mmap  = conf.use_mmap;
             params.use_mlock = conf.use_mlock;
             bindings::llama_model_load_from_file(gguf.as_ptr(), params)
         };
@@ -124,11 +125,14 @@ impl Llamacpp {
         }
         let ctx = unsafe {
             let mut params = bindings::llama_context_default_params();
-            params.n_ctx = conf.n_ctx as _;
-            params.n_threads = conf.n_threads as _;
-            params.n_threads_batch = conf.n_threads as _;
-            params.flash_attn = conf.flash_attention;
-            params.no_perf = true;
+            params.n_ctx            = conf.n_ctx as _;
+            params.n_batch          = conf.max_batch_total_tokens as _;
+            params.n_ubatch         = conf.max_batch_total_tokens as _; // TODO ?
+            params.n_seq_max        = conf.max_batch_size as _;
+            params.n_threads        = conf.n_threads as _;
+            params.n_threads_batch  = conf.n_threads as _; // TODO ?
+            params.flash_attn       = conf.flash_attention;
+            params.no_perf          = true;
             bindings::llama_init_from_model(model, params)
         };
         if ctx.is_null() {
@@ -142,11 +146,30 @@ impl Llamacpp {
         if vocab.is_null() {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
+        let n_tokens = unsafe {
+            bindings::llama_vocab_n_tokens(vocab)
+        };
+        let mut logprobs = Vec::with_capacity(n_tokens as usize);
+
+        for token in 0..n_tokens {
+            logprobs.push(bindings::llama_token_data {
+                id: token,
+                logit: 0.0,
+                p: 0.0,
+            });
+        }
         let batch = unsafe {
             bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
-        // TODO check batch
-        Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
+        Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
+    }
+
+    fn batch_clear_logits(&mut self) {
+        for n in 0..self.batch.n_tokens as usize{
+            unsafe {
+                *self.batch.logits.add(n) = 0 as i8;
+            }
+        }
     }
 
     fn batch_push(
@@ -156,6 +179,7 @@ impl Llamacpp {
         seq_ids: &[bindings::llama_seq_id],
         logits: bool,
     ) {
+        debug!("push {token} {pos} {logits}");
         // TODO check evertyhing..
         let n = self.batch.n_tokens as usize;
 
@@ -279,9 +303,29 @@ impl LlamacppSampler {
         }
     }
 
-    fn sample(&self, llamacpp: &Llamacpp) -> bindings::llama_token {
-        // use apply/accept ?
-        unsafe { bindings::llama_sampler_sample(self.chain, llamacpp.ctx, -1) }// -1 ?
+    fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) {
+        let logits = unsafe {
+            bindings::llama_get_logits_ith(llamacpp.ctx, idx as _)
+        };
+        for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
+            *logprob = bindings::llama_token_data {
+                id: token as _,
+                logit: unsafe { *logits.offset(token as _) },
+                p: 0.0,
+            };
+        }
+        let mut view = bindings::llama_token_data_array {
+            data: llamacpp.logprobs.as_mut_ptr(),
+            size: llamacpp.logprobs.len(),
+            selected: -1,
+            sorted: false,
+        };
+        unsafe {
+            bindings::llama_sampler_apply(self.chain, &mut view);
+            let logprob = *view.data.offset(view.selected as _);
+            bindings::llama_sampler_accept(self.chain, logprob.id);
+            (logprob.id, logprob.logit) // maybe p.ln() ?
+        }
     }
 }
 
@@ -321,14 +365,12 @@ impl LlamacppBackend {
                 match timeout(conf.batch_timeout, rx.recv()).await {
                     Ok(None) => break, // closed
                     Ok(Some(request)) => {
-                        if let Some(max_batch_size) = conf.max_batch_size {
-                            if requests.len() + 1 == max_batch_size {
-                                requests.push(request);
-                                let _ = sync_tx.send(requests);
-                                n_tokens = 0;
-                                requests = Vec::new();
-                                continue;
-                            }
+                        if requests.len() + 1 == conf.max_batch_size {
+                            requests.push(request);
+                            let _ = sync_tx.send(requests);
+                            n_tokens = 0;
+                            requests = Vec::new();
+                            continue;
                         }
                         if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
                             let _ = sync_tx.send(requests);
@@ -378,6 +420,8 @@ impl LlamacppBackend {
                             true,
                         );
                     }
+                    let mut pos = request.input_ids.len();
+
                     // TODO: close this loop :)
 
                     // TODO: move up for perf ?
@@ -409,15 +453,12 @@ impl LlamacppBackend {
                            break;
                         },
                     };
-                    let next = sampler.sample(&llamacpp);
-                    n_tokens += llamacpp.batch.n_tokens as usize;
-                    n_new_tokens += llamacpp.batch.n_tokens as usize;
+                    let idx = llamacpp.batch.n_tokens as usize - 1;
+                    let (next, logprob) = sampler.sample(&mut llamacpp, idx);
+                    n_new_tokens += 1;
 
                     debug!("tokens: {n_tokens} new: {n_new_tokens}");
 
-                    let logits = unsafe {
-                        *bindings::llama_get_logits_ith(llamacpp.ctx, -1)
-                    };
                     let kv_cache_used_cells = unsafe {
                         bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
                     };
@@ -437,7 +478,7 @@ impl LlamacppBackend {
                     let token = Token {
                         id: next as _,
                         text: piece,
-                        logprob: logits as _,
+                        logprob: logprob,
                         special: special,
                     };
                     let finish: Option<FinishReason> = {
@@ -471,7 +512,9 @@ impl LlamacppBackend {
                         top_tokens: vec![],
                     }));
                     llamacpp.batch.n_tokens = 0;
-                    llamacpp.batch_push(next, n_tokens as _, &[0], true);
+                    // llamacpp.batch_clear_logits();
+                    llamacpp.batch_push(next, pos as _, &[0], true);
+                    pos += 1;
                 }
             }
         } // TODO remove this
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 33614059..7eae8315 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -30,7 +30,7 @@ struct Args {
 
     /// Number of threads to use for inference.
     #[clap(default_value = "1", long, env)]
-    n_threads: isize,
+    n_threads: usize,
 
     #[clap(default_value = "true", long, env)]
     /// Whether to use memory mapping.
@@ -77,8 +77,8 @@ struct Args {
     // max_waiting_tokens: usize,
 
     /// Maximum number of requests per batch
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
+    #[clap(default_value = "1", long, env)]
+    max_batch_size: usize,
 
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -145,12 +145,10 @@ async fn main() -> Result<(), RouterError> {
             "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
     }
-    if let Some(max_batch_size) = args.max_batch_size {
-        if max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
-            return Err(RouterError::ArgumentValidation(
-                "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
-            ));
-        }
+    if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
+        return Err(RouterError::ArgumentValidation(
+            "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
+        ));
     }
     if args.max_batch_total_tokens > args.n_ctx {
         return Err(RouterError::ArgumentValidation(