From 09a745f1b86d59324d1b389a9afcc71763a53187 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?=
Date: Wed, 5 Feb 2025 11:31:58 +0000
Subject: [PATCH] Remove n_ctx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Adrien Gallouët
---
 backends/llamacpp/src/backend.rs | 13 ++-----------
 backends/llamacpp/src/main.rs    | 10 ----------
 2 files changed, 2 insertions(+), 21 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index aa44df31..d81137e6 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -130,7 +130,6 @@ impl LlamacppGGMLType {
 
 pub struct LlamacppConfig {
     pub model_gguf: String,
-    pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
     pub max_physical_batch_total_tokens: usize,
     pub max_batch_size: usize,
@@ -206,7 +205,6 @@ struct Llamacpp {
     vocab: *const llamacpp::llama_vocab,
     logprobs: Vec,
     batch: llamacpp::llama_batch,
-    n_ctx: u32,
 }
 
 extern "C" fn llamacpp_log_callback(
@@ -251,7 +249,7 @@ impl Llamacpp {
         }
         let ctx = unsafe {
             let mut params = llamacpp::context_default_params();
-            params.n_ctx = conf.n_ctx as _;
+            params.n_ctx = conf.max_batch_total_tokens as _;
             params.n_batch = conf.max_batch_total_tokens as _;
             params.n_ubatch = conf.max_physical_batch_total_tokens as _;
             params.n_seq_max = conf.max_batch_size as _;
@@ -268,8 +266,6 @@ impl Llamacpp {
         if ctx.is_null() {
             return Err(BackendError::Llamacpp("Failed to init context".to_string()))
         }
-        let n_ctx = unsafe { llamacpp::n_ctx(ctx) };
-
         let vocab = unsafe {
             llamacpp::model_get_vocab(model)
         };
@@ -291,7 +287,7 @@ impl Llamacpp {
         let batch = unsafe {
             llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
-        Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
+        Ok(Llamacpp{model, ctx, vocab, logprobs, batch})
     }
 
     fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
@@ -559,9 +555,6 @@ impl LlamacppBackend {
                     }
                    break;
                }
-               let kv_cache_used_cells = unsafe {
-                   llamacpp::get_kv_cache_used_cells(llamacpp.ctx)
-               };
                for seq in seqs.iter_mut() {
                    if !seq.running {
                        continue;
@@ -595,8 +588,6 @@ impl LlamacppBackend {
                        Some(FinishReason::EndOfSequenceToken)
                    } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
                        Some(FinishReason::Length)
-                   } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
-                       Some(FinishReason::Length) // TODO: check
                    } else {
                        None
                    }
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 5548773b..310ca8f1 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -24,10 +24,6 @@ struct Args {
     #[clap(long, env)]
     model_gguf: String, // TODO Option() with hf->gguf & quantize
 
-    /// Context size for the model.
-    #[clap(default_value = "4096", long, env)]
-    n_ctx: usize,
-
     /// Number of threads to use for generation.
     #[clap(long, env)]
     n_threads: Option,
@@ -198,11 +194,6 @@ async fn main() -> Result<(), RouterError> {
             "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
     }
-    if args.max_batch_total_tokens > args.n_ctx {
-        return Err(RouterError::ArgumentValidation(
-            "`max_batch_total_tokens` must be <= `n_ctx`".to_string(),
-        ));
-    }
 
     // TODO: check if we use the same cache of Server
     // check if llamacpp is faster
@@ -224,7 +215,6 @@ async fn main() -> Result<(), RouterError> {
     let (backend, ok, shutdown) = LlamacppBackend::new(
         LlamacppConfig {
            model_gguf: args.model_gguf,
-           n_ctx: args.n_ctx,
            n_threads: n_threads,
            n_threads_batch: n_threads_batch,
            n_gpu_layers: args.n_gpu_layers,
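
Note (not part of the patch): with this change, llama.cpp's n_ctx is taken
directly from max_batch_total_tokens, so the separate n_ctx argument, its
validation against max_batch_total_tokens, and the per-step
kv_cache_used_cells == n_ctx finish check all become unnecessary. Below is a
minimal sketch of the invariant this relies on, using hypothetical stand-in
types; only the field names and the router check that the patch keeps are
taken from the diff.

// Hypothetical stand-in for the router-side limits; field names mirror the diff.
struct Limits {
    max_batch_size: usize,
    max_total_tokens: usize,
    max_batch_total_tokens: usize,
}

// Mirrors the validation the patch keeps:
// `max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`.
// Once n_ctx is pinned to max_batch_total_tokens, a valid batch can never
// outgrow the KV cache, so a cache-full finish reason has nothing to catch.
fn n_ctx_from_limits(limits: &Limits) -> Result<usize, String> {
    if limits.max_batch_size * limits.max_total_tokens > limits.max_batch_total_tokens {
        return Err(
            "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".into(),
        );
    }
    Ok(limits.max_batch_total_tokens) // value handed to llama.cpp as n_ctx
}

fn main() {
    let limits = Limits {
        max_batch_size: 4,
        max_total_tokens: 2048,
        max_batch_total_tokens: 8192,
    };
    println!("n_ctx = {}", n_ctx_from_limits(&limits).unwrap());
}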