From 38b33e9698cf672e36ea86306549395500d924a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Mon, 3 Feb 2025 12:39:28 +0000 Subject: [PATCH] Add --type-v & --type-k MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/backend.rs | 79 ++++++++++++++++++++++++++++++++ backends/llamacpp/src/main.rs | 12 ++++- 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index bf4b19e3..04160cc4 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -52,6 +52,81 @@ pub enum LlamacppNuma { Mirror, } +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, clap::ValueEnum)] +pub enum LlamacppGGMLType { + F32, + F16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2_K, + Q3_K, + Q4_K, + Q5_K, + Q6_K, + Q8_K, + IQ2_XXS, + IQ2_XS, + IQ3_XXS, + IQ1_S, + IQ4_NL, + IQ3_S, + IQ2_S, + IQ4_XS, + I8, + I16, + I32, + I64, + F64, + IQ1_M, + BF16, + TQ1_0, + TQ2_0, +} + +// TODO: macro +impl LlamacppGGMLType { + fn to_ggml_type(&self) -> bindings::ggml_type { + match self { + LlamacppGGMLType::F32 => bindings::GGML_TYPE_F32, + LlamacppGGMLType::F16 => bindings::GGML_TYPE_F16, + LlamacppGGMLType::Q4_0 => bindings::GGML_TYPE_Q4_0, + LlamacppGGMLType::Q4_1 => bindings::GGML_TYPE_Q4_1, + LlamacppGGMLType::Q5_0 => bindings::GGML_TYPE_Q5_0, + LlamacppGGMLType::Q5_1 => bindings::GGML_TYPE_Q5_1, + LlamacppGGMLType::Q8_0 => bindings::GGML_TYPE_Q8_0, + LlamacppGGMLType::Q8_1 => bindings::GGML_TYPE_Q8_1, + LlamacppGGMLType::Q2_K => bindings::GGML_TYPE_Q2_K, + LlamacppGGMLType::Q3_K => bindings::GGML_TYPE_Q3_K, + LlamacppGGMLType::Q4_K => bindings::GGML_TYPE_Q4_K, + LlamacppGGMLType::Q5_K => bindings::GGML_TYPE_Q5_K, + LlamacppGGMLType::Q6_K => bindings::GGML_TYPE_Q6_K, + LlamacppGGMLType::Q8_K => bindings::GGML_TYPE_Q8_K, + LlamacppGGMLType::IQ2_XXS => 
bindings::GGML_TYPE_IQ2_XXS, + LlamacppGGMLType::IQ2_XS => bindings::GGML_TYPE_IQ2_XS, + LlamacppGGMLType::IQ3_XXS => bindings::GGML_TYPE_IQ3_XXS, + LlamacppGGMLType::IQ1_S => bindings::GGML_TYPE_IQ1_S, + LlamacppGGMLType::IQ4_NL => bindings::GGML_TYPE_IQ4_NL, + LlamacppGGMLType::IQ3_S => bindings::GGML_TYPE_IQ3_S, + LlamacppGGMLType::IQ2_S => bindings::GGML_TYPE_IQ2_S, + LlamacppGGMLType::IQ4_XS => bindings::GGML_TYPE_IQ4_XS, + LlamacppGGMLType::I8 => bindings::GGML_TYPE_I8, + LlamacppGGMLType::I16 => bindings::GGML_TYPE_I16, + LlamacppGGMLType::I32 => bindings::GGML_TYPE_I32, + LlamacppGGMLType::I64 => bindings::GGML_TYPE_I64, + LlamacppGGMLType::F64 => bindings::GGML_TYPE_F64, + LlamacppGGMLType::IQ1_M => bindings::GGML_TYPE_IQ1_M, + LlamacppGGMLType::BF16 => bindings::GGML_TYPE_BF16, + LlamacppGGMLType::TQ1_0 => bindings::GGML_TYPE_TQ1_0, + LlamacppGGMLType::TQ2_0 => bindings::GGML_TYPE_TQ2_0, + } + } +} + pub struct LlamacppConfig { pub model_gguf: String, pub n_ctx: usize, @@ -69,6 +144,8 @@ pub struct LlamacppConfig { pub use_mlock: bool, pub offload_kqv: bool, pub flash_attention: bool, + pub type_k: LlamacppGGMLType, + pub type_v: LlamacppGGMLType, } #[derive(Debug)] @@ -182,6 +259,8 @@ impl Llamacpp { params.defrag_thold = conf.defrag_threshold; params.offload_kqv = conf.offload_kqv; params.flash_attn = conf.flash_attention; + params.type_k = conf.type_k.to_ggml_type(); + params.type_v = conf.type_v.to_ggml_type(); params.no_perf = true; bindings::llama_init_from_model(model, params) }; diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index 55881b13..dba391c0 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -1,6 +1,6 @@ mod backend; -use backend::{LlamacppNuma, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError}; +use backend::{LlamacppNuma, LlamacppGGMLType, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError}; use clap::{Parser}; use text_generation_router::{logging, 
server, usage_stats}; use thiserror::Error; @@ -68,6 +68,14 @@ struct Args { #[clap(default_value = "true", long, env)] flash_attention: bool, + /// Data type to use for the K cache. + #[clap(default_value = "f16", value_enum, long, env)] + type_k: LlamacppGGMLType, + + /// Data type to use for the V cache. + #[clap(default_value = "f16", value_enum, long, env)] + type_v: LlamacppGGMLType, + /// TODO #[clap(default_value = "2", long, env)] validation_workers: usize, @@ -226,6 +234,8 @@ async fn main() -> Result<(), RouterError> { use_mmap: args.use_mmap, use_mlock: args.use_mlock, flash_attention: args.flash_attention, + type_k: args.type_k, + type_v: args.type_v, offload_kqv: args.offload_kqv, max_batch_total_tokens: args.max_batch_total_tokens, max_physical_batch_total_tokens: max_physical_batch_total_tokens,