Handle ctx args & fix sampling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Author: Adrien Gallouët
Date: 2025-01-30 22:41:26 +00:00
Parent: a7b4b04cb5
Commit: 8d2dfdf668
3 changed files with 80 additions and 39 deletions

View File

@@ -18,7 +18,7 @@ RUN apt-get install -y \
# nvidia-cuda-toolkit
# -DGGML_CUDA=ON \
ENV LLAMA_VERSION=b4585
ENV LLAMA_VERSION=b4599
RUN git clone --depth 1 -b ${LLAMA_VERSION} https://github.com/ggerganov/llama.cpp \
&& cd llama.cpp \
&& cmake -B build \

View File

@@ -25,9 +25,9 @@ pub struct LlamacppConfig {
pub model_gguf: String,
pub n_ctx: usize,
pub max_batch_total_tokens: usize,
pub max_batch_size: Option<usize>,
pub max_batch_size: usize,
pub batch_timeout: Duration,
pub n_threads: isize,
pub n_threads: usize,
pub use_mmap: bool,
pub use_mlock: bool,
pub flash_attention: bool,
@@ -88,6 +88,7 @@ struct Llamacpp {
model: *mut bindings::llama_model,
ctx: *mut bindings::llama_context,
vocab: *const bindings::llama_vocab,
logprobs: Vec<bindings::llama_token_data>,
batch: bindings::llama_batch,
n_ctx: u32,
}
@@ -115,7 +116,7 @@ impl Llamacpp {
let model = unsafe {
let mut params = bindings::llama_model_default_params();
params.use_mmap = conf.use_mmap;
params.use_mmap = conf.use_mmap;
params.use_mlock = conf.use_mlock;
bindings::llama_model_load_from_file(gguf.as_ptr(), params)
};
@@ -124,11 +125,14 @@
}
let ctx = unsafe {
let mut params = bindings::llama_context_default_params();
params.n_ctx = conf.n_ctx as _;
params.n_threads = conf.n_threads as _;
params.n_threads_batch = conf.n_threads as _;
params.flash_attn = conf.flash_attention;
params.no_perf = true;
params.n_ctx = conf.n_ctx as _;
params.n_batch = conf.max_batch_total_tokens as _;
params.n_ubatch = conf.max_batch_total_tokens as _; // TODO ?
params.n_seq_max = conf.max_batch_size as _;
params.n_threads = conf.n_threads as _;
params.n_threads_batch = conf.n_threads as _; // TODO ?
params.flash_attn = conf.flash_attention;
params.no_perf = true;
bindings::llama_init_from_model(model, params)
};
if ctx.is_null() {
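For reference, a sketch of how the config fields now map onto llama.cpp's context parameters, assuming the bindgen-generated `bindings` module and the `LlamacppConfig` struct above; the helper name is illustrative and the comments give the usual meaning of each field rather than anything specific to this backend:

```rust
// Sketch only, not the backend's exact code.
unsafe fn context_params(conf: &LlamacppConfig) -> bindings::llama_context_params {
    let mut params = unsafe { bindings::llama_context_default_params() };
    params.n_ctx = conf.n_ctx as _;                     // KV-cache size, in tokens
    params.n_batch = conf.max_batch_total_tokens as _;  // max tokens per llama_decode() call
    params.n_ubatch = conf.max_batch_total_tokens as _; // physical micro-batch, kept equal here
    params.n_seq_max = conf.max_batch_size as _;        // sequences that may share one batch
    params.n_threads = conf.n_threads as _;             // generation threads
    params.n_threads_batch = conf.n_threads as _;       // prompt-processing threads, same count
    params.flash_attn = conf.flash_attention;
    params.no_perf = true;                              // skip internal perf counters
    params
}
```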
@@ -142,11 +146,30 @@ impl Llamacpp {
if vocab.is_null() {
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
}
let n_tokens = unsafe {
bindings::llama_vocab_n_tokens(vocab)
};
let mut logprobs = Vec::with_capacity(n_tokens as usize);
for token in 0..n_tokens {
logprobs.push(bindings::llama_token_data {
id: token,
logit: 0.0,
p: 0.0,
});
}
let batch = unsafe {
bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
};
// TODO check batch
Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
}
fn batch_clear_logits(&mut self) {
for n in 0..self.batch.n_tokens as usize{
unsafe {
*self.batch.logits.add(n) = 0 as i8;
}
}
}
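Two small pieces support the reworked sampling path: a `logprobs` buffer with one `llama_token_data` per vocabulary entry, allocated once so the candidate array can be rebuilt in place at every step, and `batch_clear_logits`, which zeroes the per-position flags that tell llama.cpp which batch entries should produce logits on the next decode. A minimal sketch of the preallocation, assuming the generated `bindings` module (`init_logprobs` is a hypothetical name):

```rust
// Sketch: build the reusable candidate buffer once, one entry per vocab token.
// `vocab` is the same pointer validated above.
unsafe fn init_logprobs(vocab: *const bindings::llama_vocab) -> Vec<bindings::llama_token_data> {
    let n_tokens = unsafe { bindings::llama_vocab_n_tokens(vocab) };
    (0..n_tokens)
        .map(|id| bindings::llama_token_data { id, logit: 0.0, p: 0.0 })
        .collect()
}
```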
fn batch_push(
@@ -156,6 +179,7 @@ impl Llamacpp {
seq_ids: &[bindings::llama_seq_id],
logits: bool,
) {
debug!("push {token} {pos} {logits}");
// TODO check everything..
let n = self.batch.n_tokens as usize;
@@ -279,9 +303,29 @@ impl LlamacppSampler {
}
}
fn sample(&self, llamacpp: &Llamacpp) -> bindings::llama_token {
// use apply/accept ?
unsafe { bindings::llama_sampler_sample(self.chain, llamacpp.ctx, -1) }// -1 ?
fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) {
let logits = unsafe {
bindings::llama_get_logits_ith(llamacpp.ctx, idx as _)
};
for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
*logprob = bindings::llama_token_data {
id: token as _,
logit: unsafe { *logits.offset(token as _) },
p: 0.0,
};
}
let mut view = bindings::llama_token_data_array {
data: llamacpp.logprobs.as_mut_ptr(),
size: llamacpp.logprobs.len(),
selected: -1,
sorted: false,
};
unsafe {
bindings::llama_sampler_apply(self.chain, &mut view);
let logprob = *view.data.offset(view.selected as _);
bindings::llama_sampler_accept(self.chain, logprob.id);
(logprob.id, logprob.logit) // maybe p.ln() ?
}
}
}
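The old path let `llama_sampler_sample` pick a token internally, which left no easy way to report a per-token logprob. The rewritten `sample` builds a `llama_token_data_array` over the preallocated buffer, lets the sampler chain choose via `llama_sampler_apply`, and feeds the choice back with `llama_sampler_accept` so stateful samplers (repetition penalties, grammars) stay in sync. An annotated sketch of that flow, assuming the generated `bindings` module (`sample_at` is a hypothetical name, not the backend's):

```rust
// `chain` is the configured sampler chain, `ctx` the llama_context, and
// `candidates` the vocab-sized buffer allocated at load time.
unsafe fn sample_at(
    chain: *mut bindings::llama_sampler,
    ctx: *mut bindings::llama_context,
    candidates: &mut [bindings::llama_token_data],
    idx: usize,
) -> (bindings::llama_token, f32) {
    // Raw logits for batch position `idx`: one float per vocab token.
    let logits = unsafe { bindings::llama_get_logits_ith(ctx, idx as _) };
    for (id, cand) in candidates.iter_mut().enumerate() {
        cand.id = id as _;
        cand.logit = unsafe { *logits.add(id) };
        cand.p = 0.0;
    }
    let mut view = bindings::llama_token_data_array {
        data: candidates.as_mut_ptr(),
        size: candidates.len(),
        selected: -1, // filled in by the sampler chain
        sorted: false,
    };
    unsafe {
        bindings::llama_sampler_apply(chain, &mut view);     // run top-k/top-p/temperature/...
        let chosen = *view.data.add(view.selected as usize); // the token the chain selected
        bindings::llama_sampler_accept(chain, chosen.id);    // keep stateful samplers in sync
        (chosen.id, chosen.logit)
    }
}
```

As the `// maybe p.ln() ?` note hints, what gets returned (and later stored in `Token::logprob`) is the selected token's raw logit, not a normalized log-probability.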
@@ -321,14 +365,12 @@ impl LlamacppBackend {
match timeout(conf.batch_timeout, rx.recv()).await {
Ok(None) => break, // closed
Ok(Some(request)) => {
if let Some(max_batch_size) = conf.max_batch_size {
if requests.len() + 1 == max_batch_size {
requests.push(request);
let _ = sync_tx.send(requests);
n_tokens = 0;
requests = Vec::new();
continue;
}
if requests.len() + 1 == conf.max_batch_size {
requests.push(request);
let _ = sync_tx.send(requests);
n_tokens = 0;
requests = Vec::new();
continue;
}
if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
let _ = sync_tx.send(requests);
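With `max_batch_size` no longer optional (it defaults to 1), the batching loop above has two distinct flush triggers: reaching the request-count limit sends the batch including the incoming request, while overflowing the token budget sends the accumulated batch first and lets the incoming request start the next one. A hypothetical condensation of that decision (illustrative names, not the backend's):

```rust
enum Flush {
    WithRequest,   // size limit hit: include the request, then send the batch
    BeforeRequest, // token budget hit: send what we have, the request starts a new batch
    No,
}

fn flush_decision(
    pending: usize,        // requests already queued
    batch_tokens: usize,   // input tokens already queued
    request_tokens: usize, // input tokens of the incoming request
    max_batch_size: usize,
    max_batch_total_tokens: usize,
) -> Flush {
    if pending + 1 == max_batch_size {
        Flush::WithRequest
    } else if batch_tokens + request_tokens > max_batch_total_tokens {
        Flush::BeforeRequest
    } else {
        Flush::No
    }
}

fn main() {
    // The 4th request arriving when max_batch_size is 4 ships with the batch...
    assert!(matches!(flush_decision(3, 100, 10, 4, 512), Flush::WithRequest));
    // ...while a request that would blow the token budget waits for the next batch.
    assert!(matches!(flush_decision(2, 500, 100, 4, 512), Flush::BeforeRequest));
}
```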
@@ -378,6 +420,8 @@ impl LlamacppBackend {
true,
);
}
let mut pos = request.input_ids.len();
// TODO: close this loop :)
// TODO: move up for perf ?
@@ -409,15 +453,12 @@ impl LlamacppBackend {
break;
},
};
let next = sampler.sample(&llamacpp);
n_tokens += llamacpp.batch.n_tokens as usize;
n_new_tokens += llamacpp.batch.n_tokens as usize;
let idx = llamacpp.batch.n_tokens as usize - 1;
let (next, logprob) = sampler.sample(&mut llamacpp, idx);
n_new_tokens += 1;
debug!("tokens: {n_tokens} new: {n_new_tokens}");
let logits = unsafe {
*bindings::llama_get_logits_ith(llamacpp.ctx, -1)
};
let kv_cache_used_cells = unsafe {
bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
};
@@ -437,7 +478,7 @@ impl LlamacppBackend {
let token = Token {
id: next as _,
text: piece,
logprob: logits as _,
logprob: logprob,
special: special,
};
let finish: Option<FinishReason> = {
@@ -471,7 +512,9 @@ impl LlamacppBackend {
top_tokens: vec![],
}));
llamacpp.batch.n_tokens = 0;
llamacpp.batch_push(next, n_tokens as _, &[0], true);
// llamacpp.batch_clear_logits();
llamacpp.batch_push(next, pos as _, &[0], true);
pos += 1;
}
}
} // TODO remove this
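The other positional fix: the old loop re-pushed each sampled token at `n_tokens`, a counter that also grows with the batch, whereas every request now carries its own `pos`, initialized to the prompt length and advanced once per generated token. A toy, self-contained model of that bookkeeping (illustrative names only):

```rust
// The prompt occupies positions 0..input_len; generated tokens follow one by one.
fn decode_positions(input_len: usize, n_generated: usize) -> Vec<usize> {
    (0..n_generated).map(|i| input_len + i).collect()
}

fn main() {
    // A 5-token prompt followed by 3 generated tokens decodes at positions 5, 6, 7.
    assert_eq!(decode_positions(5, 3), vec![5, 6, 7]);
}
```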

View File

@@ -30,7 +30,7 @@ struct Args {
/// Number of threads to use for inference.
#[clap(default_value = "1", long, env)]
n_threads: isize,
n_threads: usize,
#[clap(default_value = "true", long, env)]
/// Whether to use memory mapping.
@@ -77,8 +77,8 @@ struct Args {
// max_waiting_tokens: usize,
/// Maximum number of requests per batch
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1", long, env)]
max_batch_size: usize,
/// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)]
@@ -145,12 +145,10 @@ async fn main() -> Result<(), RouterError> {
"`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
if let Some(max_batch_size) = args.max_batch_size {
if max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
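A worked instance of the inequality checked above, with hypothetical numbers: the per-batch token budget has to cover `max_batch_size` requests that may each grow to `max_total_tokens`, and the following check additionally requires `n_ctx` to cover that budget.

```rust
fn main() {
    let max_batch_size = 4usize;
    let max_total_tokens = 2048usize;
    let max_batch_total_tokens = 8192usize;
    // 4 * 2048 = 8192 <= 8192, so this configuration passes here; any smaller
    // max_batch_total_tokens would be rejected, and n_ctx must then be >= 8192
    // to satisfy the next check.
    assert!(max_batch_size * max_total_tokens <= max_batch_total_tokens);
}
```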
if args.max_batch_total_tokens > args.n_ctx {
return Err(RouterError::ArgumentValidation(