Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët 2025-02-06 14:58:44 +00:00
parent 5367d94f34
commit 809e288b5a
2 changed files with 120 additions and 119 deletions

@@ -262,15 +262,11 @@ impl Llamacpp {
         if ctx.is_null() {
             return Err(BackendError::Llamacpp("Failed to init context".to_string()));
         }
-        let vocab = unsafe {
-            llamacpp::model_get_vocab(model)
-        };
+        let vocab = unsafe { llamacpp::model_get_vocab(model) };
         if vocab.is_null() {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
-        let n_tokens = unsafe {
-            llamacpp::vocab_n_tokens(vocab)
-        };
+        let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
         let mut logprobs = Vec::with_capacity(n_tokens as usize);
         for token in 0..n_tokens {
@@ -280,16 +276,18 @@ impl Llamacpp {
                 p: 0.0,
             });
         }
-        let batch = unsafe {
-            llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
-        };
-        Ok(Llamacpp{model, ctx, vocab, logprobs, batch})
+        let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
+        Ok(Llamacpp {
+            model,
+            ctx,
+            vocab,
+            logprobs,
+            batch,
+        })
     }

     fn decode(&mut self) -> i32 {
-        unsafe {
-            llamacpp::decode(self.ctx, self.batch)
-        }
+        unsafe { llamacpp::decode(self.ctx, self.batch) }
     }

     fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
@@ -344,18 +342,10 @@ impl LlamacppSampler {
             error!("Failed to init sampler");
             return None;
         }
-        let top_k = unsafe {
-            llamacpp::sampler_init_top_k(req.top_k)
-        };
-        let top_p = unsafe {
-            llamacpp::sampler_init_top_p(req.top_p, req.min_keep)
-        };
-        let typical_p = unsafe {
-            llamacpp::sampler_init_typical(req.typical_p, req.min_keep)
-        };
-        let temp = unsafe {
-            llamacpp::sampler_init_temp(req.temp)
-        };
+        let top_k = unsafe { llamacpp::sampler_init_top_k(req.top_k) };
+        let top_p = unsafe { llamacpp::sampler_init_top_p(req.top_p, req.min_keep) };
+        let typical_p = unsafe { llamacpp::sampler_init_typical(req.typical_p, req.min_keep) };
+        let temp = unsafe { llamacpp::sampler_init_temp(req.temp) };
         let penalties = unsafe {
             llamacpp::sampler_init_penalties(
                 req.penalty_last_n,
@@ -364,9 +354,7 @@ impl LlamacppSampler {
                 req.penalty_present,
             )
         };
-        let dist = unsafe {
-            llamacpp::sampler_init_dist(req.seed)
-        };
+        let dist = unsafe { llamacpp::sampler_init_dist(req.seed) };
         let all = &[
             ("top_k", top_k),
             ("top_p", top_p),
@@ -389,14 +377,12 @@ impl LlamacppSampler {
             unsafe { llamacpp::sampler_free(chain) };
             None
         } else {
-            Some(LlamacppSampler{chain})
+            Some(LlamacppSampler { chain })
         }
     }

     fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
-        let logits = unsafe {
-            llamacpp::get_logits_ith(llamacpp.ctx, idx as _)
-        };
+        let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
         for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
             *logprob = llamacpp::llama_token_data {
                 id: token as _,
@@ -474,7 +460,8 @@ impl LlamacppBackend {
         let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
             if !requests.is_empty() {
-                let _ = sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
+                let _ =
+                    sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
                 *n_tokens = 0;
             }
         };
@@ -559,7 +546,9 @@ impl LlamacppBackend {
                     warn!("llama_decode failed, clearing kv cache");
                     llamacpp.clear_kv_cache(-1);
                     for seq in seqs.iter_mut() {
-                        let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                        let _ = requests[seq.id]
+                            .tx
+                            .send(Err(InferError::IncompleteGeneration));
                         seq.running = false;
                     }
                     break;
@@ -576,7 +565,9 @@ impl LlamacppBackend {
                         Ok(piece) => piece,
                         Err(e) => {
                             error!("Failed to decode token: {e}");
-                            let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                            let _ = requests[seq.id]
+                                .tx
+                                .send(Err(InferError::IncompleteGeneration));
                             seq.running = false;
                             continue;
                         }
@@ -617,7 +608,9 @@ impl LlamacppBackend {
                             seq.running = false;
                             continue;
                         }
-                        let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
-                            token,
-                            top_tokens: vec![],
-                        }));
+                        let _ = requests[seq.id]
+                            .tx
+                            .send(Ok(InferStreamResponse::Intermediate {
+                                token,
+                                top_tokens: vec![],
+                            }));
@@ -627,7 +620,8 @@ impl LlamacppBackend {
                 for seq in seqs.iter_mut() {
                     if seq.running {
-                        seq.batch_pos = llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+                        seq.batch_pos =
+                            llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
                         seq.pos += 1;
                     } else {
                         llamacpp.clear_kv_cache(seq.id as _);
@@ -636,7 +630,14 @@ impl LlamacppBackend {
                 }
             }
         });
-        (Self{tx, status: status_rx}, ok_rx, shutdown_tx)
+        (
+            Self {
+                tx,
+                status: status_rx,
+            },
+            ok_rx,
+            shutdown_tx,
+        )
     }
 }
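
Every hunk above is mechanical reformatting with no behavioral change, and the new shapes match rustfmt's defaults: a block whose body is a single short expression collapses onto one line when it fits within max_width, while a struct literal whose body exceeds the default struct_lit_width is expanded to one field per line with trailing commas. A minimal, self-contained sketch of both idioms, assuming default rustfmt settings (hypothetical names, not taken from this repository):

    // Hypothetical stand-in for an FFI-style call, for illustration only.
    unsafe fn ffi_like_call(x: i32) -> i32 {
        x + 1
    }

    struct Handles {
        model: i32,
        ctx: i32,
        vocab: i32,
        logprobs: i32,
        batch: i32,
    }

    fn main() {
        // A single-expression block fits within max_width, so rustfmt keeps it
        // on one line, as in `let vocab = unsafe { llamacpp::model_get_vocab(model) };`.
        let model = unsafe { ffi_like_call(41) };

        // The literal body is wider than rustfmt's default struct_lit_width, so it
        // is broken one field per line, as in the `Ok(Llamacpp { .. })` hunk.
        let handles = Handles {
            model,
            ctx: 0,
            vocab: 0,
            logprobs: 0,
            batch: 0,
        };
        println!("{} {}", handles.model, handles.batch);
    }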