commit 809e288b5a (parent 5367d94f34)
Author: Adrien Gallouët <angt@huggingface.co>
Date: 2025-02-06 14:58:44 +00:00
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2 changed files with 120 additions and 119 deletions

View File

@@ -35,9 +35,9 @@ impl FromStr for LlamacppSplitMode {
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "layer" => Ok(LlamacppSplitMode::Layer),
            "row" => Ok(LlamacppSplitMode::Row),
            _ => match s.parse::<usize>() {
                Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
                Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
            },
        }
@@ -93,37 +93,37 @@ pub enum LlamacppGGMLType {
impl LlamacppGGMLType {
    fn to_ggml_type(&self) -> llamacpp::ggml_type {
        match self {
            LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
            LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
            LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
            LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
            LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
            LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
            LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
            LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
            LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
            LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
            LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
            LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
            LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
            LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
            LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
            LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
            LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
            LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
            LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
            LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
            LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
            LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
            LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
            LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
            LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
            LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
            LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
            LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
            LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
            LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
            LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
        }
    }
}
@@ -177,18 +177,18 @@ impl LlamacppRequest {
        tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
    ) -> Option<Self> {
        from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
            input_ids: input_ids.iter().map(|&x| x as i32).collect(),
            top_k: from.parameters.top_k as _,
            top_p: from.parameters.top_p as _,
            typical_p: from.parameters.typical_p as _,
            min_keep: 0, // disabled
            temp: from.parameters.temperature as _,
            seed: from.parameters.seed as _,
            penalty_last_n: 64, // 0 = disabled, -1 = context size
            penalty_repeat: from.parameters.repetition_penalty as _,
            penalty_freq: from.parameters.frequency_penalty as _,
            penalty_present: 0.0, // disabled
            max_new_tokens: from.stopping_parameters.max_new_tokens as _,
            tx,
            time: Instant::now(),
        })
@@ -213,10 +213,10 @@ extern "C" fn llamacpp_log_callback(
        match level {
            llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
            llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
            llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
            llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
            _ => trace!(target: "llamacpp", "{}", rmsg),
        }
    }
@@ -229,14 +229,14 @@ impl Llamacpp {
            params.n_gpu_layers = conf.n_gpu_layers as _;
            params.split_mode = match conf.split_mode {
                LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
                LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
                LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
            };
            params.main_gpu = match conf.split_mode {
                LlamacppSplitMode::GPU(n) => n as _,
                _ => 0,
            };
            params.use_mmap = conf.use_mmap;
            params.use_mlock = conf.use_mlock;
            llamacpp::model_load_from_file(gguf.as_ptr(), params)
        };
@@ -245,32 +245,28 @@ impl Llamacpp {
        }
        let ctx = unsafe {
            let mut params = llamacpp::context_default_params();
            params.n_ctx = conf.max_batch_total_tokens as _;
            params.n_batch = conf.max_batch_total_tokens as _;
            params.n_ubatch = conf.max_physical_batch_total_tokens as _;
            params.n_seq_max = conf.max_batch_size as _;
            params.n_threads = conf.n_threads as _;
            params.n_threads_batch = conf.n_threads_batch as _;
            params.defrag_thold = conf.defrag_threshold;
            params.offload_kqv = conf.offload_kqv;
            params.flash_attn = conf.flash_attention;
            params.type_k = conf.type_k.to_ggml_type();
            params.type_v = conf.type_v.to_ggml_type();
            params.no_perf = true;
            llamacpp::init_from_model(model, params)
        };
        if ctx.is_null() {
            return Err(BackendError::Llamacpp("Failed to init context".to_string()));
        }
-        let vocab = unsafe {
-            llamacpp::model_get_vocab(model)
-        };
+        let vocab = unsafe { llamacpp::model_get_vocab(model) };
        if vocab.is_null() {
            return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
        }
-        let n_tokens = unsafe {
-            llamacpp::vocab_n_tokens(vocab)
-        };
+        let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
        let mut logprobs = Vec::with_capacity(n_tokens as usize);
        for token in 0..n_tokens {
@@ -280,16 +276,18 @@ impl Llamacpp {
                p: 0.0,
            });
        }
-        let batch = unsafe {
-            llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
-        };
-        Ok(Llamacpp{model, ctx, vocab, logprobs, batch})
+        let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
+        Ok(Llamacpp {
+            model,
+            ctx,
+            vocab,
+            logprobs,
+            batch,
+        })
    }

    fn decode(&mut self) -> i32 {
-        unsafe {
-            llamacpp::decode(self.ctx, self.batch)
-        }
+        unsafe { llamacpp::decode(self.ctx, self.batch) }
    }

    fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
@@ -344,18 +342,10 @@ impl LlamacppSampler {
            error!("Failed to init sampler");
            return None;
        }
-        let top_k = unsafe {
-            llamacpp::sampler_init_top_k(req.top_k)
-        };
-        let top_p = unsafe {
-            llamacpp::sampler_init_top_p(req.top_p, req.min_keep)
-        };
-        let typical_p = unsafe {
-            llamacpp::sampler_init_typical(req.typical_p, req.min_keep)
-        };
-        let temp = unsafe {
-            llamacpp::sampler_init_temp(req.temp)
-        };
+        let top_k = unsafe { llamacpp::sampler_init_top_k(req.top_k) };
+        let top_p = unsafe { llamacpp::sampler_init_top_p(req.top_p, req.min_keep) };
+        let typical_p = unsafe { llamacpp::sampler_init_typical(req.typical_p, req.min_keep) };
+        let temp = unsafe { llamacpp::sampler_init_temp(req.temp) };
        let penalties = unsafe {
            llamacpp::sampler_init_penalties(
                req.penalty_last_n,
@@ -364,9 +354,7 @@ impl LlamacppSampler {
                req.penalty_present,
            )
        };
-        let dist = unsafe {
-            llamacpp::sampler_init_dist(req.seed)
-        };
+        let dist = unsafe { llamacpp::sampler_init_dist(req.seed) };
        let all = &[
            ("top_k", top_k),
            ("top_p", top_p),
@@ -389,14 +377,12 @@ impl LlamacppSampler {
            unsafe { llamacpp::sampler_free(chain) };
            None
        } else {
-            Some(LlamacppSampler{chain})
+            Some(LlamacppSampler { chain })
        }
    }

    fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
-        let logits = unsafe {
-            llamacpp::get_logits_ith(llamacpp.ctx, idx as _)
-        };
+        let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
        for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
            *logprob = llamacpp::llama_token_data {
                id: token as _,
@@ -454,11 +440,11 @@ impl LlamacppBackend {
            llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
            llamacpp::backend_init();
            llamacpp::numa_init(match conf.numa {
                LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
                LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
                LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
                LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
                LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
            });
        });
@@ -474,7 +460,8 @@ impl LlamacppBackend {
        let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
            if !requests.is_empty() {
-                let _ = sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
+                let _ =
+                    sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
                *n_tokens = 0;
            }
        };
@@ -538,8 +525,8 @@ impl LlamacppBackend {
                for (pos, &token_id) in request.input_ids.iter().enumerate() {
                    llamacpp.batch_push(
                        token_id as llamacpp::llama_token,
                        pos as llamacpp::llama_pos,
                        seq_id as llamacpp::llama_seq_id,
                        pos == last_pos, // check samplers
                    );
                }
@@ -559,7 +546,9 @@ impl LlamacppBackend {
                    warn!("llama_decode failed, clearing kv cache");
                    llamacpp.clear_kv_cache(-1);
                    for seq in seqs.iter_mut() {
-                        let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                        let _ = requests[seq.id]
+                            .tx
+                            .send(Err(InferError::IncompleteGeneration));
                        seq.running = false;
                    }
                    break;
@@ -576,7 +565,9 @@ impl LlamacppBackend {
                        Ok(piece) => piece,
                        Err(e) => {
                            error!("Failed to decode token: {e}");
-                            let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                            let _ = requests[seq.id]
+                                .tx
+                                .send(Err(InferError::IncompleteGeneration));
                            seq.running = false;
                            continue;
                        }
@@ -617,17 +608,20 @@ impl LlamacppBackend {
                        seq.running = false;
                        continue;
                    }
-                    let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
-                        token,
-                        top_tokens: vec![],
-                    }));
+                    let _ = requests[seq.id]
+                        .tx
+                        .send(Ok(InferStreamResponse::Intermediate {
+                            token,
+                            top_tokens: vec![],
+                        }));
                }
                // generate a new batch
                llamacpp.batch.n_tokens = 0;
                for seq in seqs.iter_mut() {
                    if seq.running {
-                        seq.batch_pos = llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+                        seq.batch_pos =
+                            llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
                        seq.pos += 1;
                    } else {
                        llamacpp.clear_kv_cache(seq.id as _);
@@ -636,7 +630,14 @@ impl LlamacppBackend {
                }
            }
        });
-        (Self{tx, status: status_rx}, ok_rx, shutdown_tx)
+        (
+            Self {
+                tx,
+                status: status_rx,
+            },
+            ok_rx,
+            shutdown_tx,
+        )
    }
}

View File

@@ -222,23 +222,23 @@ async fn main() -> Result<(), RouterError> {
    let (backend, ok, shutdown) = LlamacppBackend::new(
        LlamacppConfig {
            model_gguf: args.model_gguf,
            n_threads,
            n_threads_batch,
            n_gpu_layers: args.n_gpu_layers,
            split_mode: args.split_mode,
            defrag_threshold: args.defrag_threshold,
            numa: args.numa,
            use_mmap: args.use_mmap,
            use_mlock: args.use_mlock,
            flash_attention: args.flash_attention,
            type_k: args.type_k,
            type_v: args.type_v,
            offload_kqv: args.offload_kqv,
            max_batch_total_tokens,
            max_physical_batch_total_tokens,
            max_batch_size,
            batch_timeout: tokio::time::Duration::from_millis(5),
        },
        tokenizer,
    );
@@ -261,7 +261,7 @@ async fn main() -> Result<(), RouterError> {
        args.max_input_tokens,
        args.max_total_tokens,
        args.validation_workers,
        None, // api_key
        args.model_id, // tokenizer_name
        args.tokenizer_config_path,
        Some(args.revision),