Handle ctx args & fix sampling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët 2025-01-30 22:41:26 +00:00
parent a7b4b04cb5
commit 8d2dfdf668
3 changed files with 80 additions and 39 deletions


@@ -18,7 +18,7 @@ RUN apt-get install -y \
 #    nvidia-cuda-toolkit
 #    -DGGML_CUDA=ON \
-ENV LLAMA_VERSION=b4585
+ENV LLAMA_VERSION=b4599
 RUN git clone --depth 1 -b ${LLAMA_VERSION} https://github.com/ggerganov/llama.cpp \
     && cd llama.cpp \
     && cmake -B build \


@@ -25,9 +25,9 @@ pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
-    pub max_batch_size: Option<usize>,
+    pub max_batch_size: usize,
     pub batch_timeout: Duration,
-    pub n_threads: isize,
+    pub n_threads: usize,
     pub use_mmap: bool,
     pub use_mlock: bool,
     pub flash_attention: bool,
@@ -88,6 +88,7 @@ struct Llamacpp {
     model: *mut bindings::llama_model,
     ctx: *mut bindings::llama_context,
     vocab: *const bindings::llama_vocab,
+    logprobs: Vec<bindings::llama_token_data>,
     batch: bindings::llama_batch,
     n_ctx: u32,
 }
@ -125,8 +126,11 @@ impl Llamacpp {
let ctx = unsafe { let ctx = unsafe {
let mut params = bindings::llama_context_default_params(); let mut params = bindings::llama_context_default_params();
params.n_ctx = conf.n_ctx as _; params.n_ctx = conf.n_ctx as _;
params.n_batch = conf.max_batch_total_tokens as _;
params.n_ubatch = conf.max_batch_total_tokens as _; // TODO ?
params.n_seq_max = conf.max_batch_size as _;
params.n_threads = conf.n_threads as _; params.n_threads = conf.n_threads as _;
params.n_threads_batch = conf.n_threads as _; params.n_threads_batch = conf.n_threads as _; // TODO ?
params.flash_attn = conf.flash_attention; params.flash_attn = conf.flash_attention;
params.no_perf = true; params.no_perf = true;
bindings::llama_init_from_model(model, params) bindings::llama_init_from_model(model, params)
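
Note on the new context parameters (not part of the diff): in llama.cpp, n_ctx sets the total KV-cache capacity in tokens, n_batch is the largest number of tokens passed to a single llama_decode call, n_ubatch is the physical micro-batch actually run through the model, and n_seq_max bounds how many sequences can be decoded in parallel. A minimal sketch of the intended mapping, assuming the LlamacppConfig fields shown earlier in this diff:

// Sketch only, not part of the commit: how the server config is expected
// to map onto llama.cpp context parameters.
fn context_params_sketch(conf: &LlamacppConfig) -> (u32, u32, u32, u32) {
    let n_ctx     = conf.n_ctx as u32;                   // total KV-cache capacity
    let n_batch   = conf.max_batch_total_tokens as u32;  // logical batch per decode call
    let n_ubatch  = conf.max_batch_total_tokens as u32;  // physical micro-batch (hence the "TODO ?")
    let n_seq_max = conf.max_batch_size as u32;          // parallel sequences
    (n_ctx, n_batch, n_ubatch, n_seq_max)
}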
@@ -142,11 +146,30 @@ impl Llamacpp {
         if vocab.is_null() {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
+        let n_tokens = unsafe {
+            bindings::llama_vocab_n_tokens(vocab)
+        };
+        let mut logprobs = Vec::with_capacity(n_tokens as usize);
+
+        for token in 0..n_tokens {
+            logprobs.push(bindings::llama_token_data {
+                id: token,
+                logit: 0.0,
+                p: 0.0,
+            });
+        }
         let batch = unsafe {
             bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
-        // TODO check batch
-        Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
+        Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
+    }
+
+    fn batch_clear_logits(&mut self) {
+        for n in 0..self.batch.n_tokens as usize {
+            unsafe {
+                *self.batch.logits.add(n) = 0 as i8;
+            }
+        }
     }

     fn batch_push(
@@ -156,6 +179,7 @@ impl Llamacpp {
         seq_ids: &[bindings::llama_seq_id],
         logits: bool,
     ) {
+        debug!("push {token} {pos} {logits}");
         // TODO check evertyhing..
         let n = self.batch.n_tokens as usize;
@@ -279,9 +303,29 @@ impl LlamacppSampler {
         }
     }

-    fn sample(&self, llamacpp: &Llamacpp) -> bindings::llama_token {
-        // use apply/accept ?
-        unsafe { bindings::llama_sampler_sample(self.chain, llamacpp.ctx, -1) }// -1 ?
+    fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) {
+        let logits = unsafe {
+            bindings::llama_get_logits_ith(llamacpp.ctx, idx as _)
+        };
+        for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
+            *logprob = bindings::llama_token_data {
+                id: token as _,
+                logit: unsafe { *logits.offset(token as _) },
+                p: 0.0,
+            };
+        }
+        let mut view = bindings::llama_token_data_array {
+            data: llamacpp.logprobs.as_mut_ptr(),
+            size: llamacpp.logprobs.len(),
+            selected: -1,
+            sorted: false,
+        };
+        unsafe {
+            bindings::llama_sampler_apply(self.chain, &mut view);
+            let logprob = *view.data.offset(view.selected as _);
+            bindings::llama_sampler_accept(self.chain, logprob.id);
+            (logprob.id, logprob.logit) // maybe p.ln() ?
+        }
     }
 }
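
On the `// maybe p.ln() ?` question: as written, sample() returns the raw logit of the selected token, which is not a log-probability. A true logprob can be derived from the logits that were just copied, without depending on the sampler chain filling `p`. A minimal sketch under that assumption (plain Rust, not part of the commit):

// Sketch: log-probability of `selected` from raw logits via a
// numerically stable log-softmax. Illustrative only.
fn logprob_sketch(logits: &[f32], selected: usize) -> f32 {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln() + max;
    logits[selected] - log_sum_exp
}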
@@ -321,15 +365,13 @@ impl LlamacppBackend {
                 match timeout(conf.batch_timeout, rx.recv()).await {
                     Ok(None) => break, // closed
                     Ok(Some(request)) => {
-                        if let Some(max_batch_size) = conf.max_batch_size {
-                            if requests.len() + 1 == max_batch_size {
-                                requests.push(request);
-                                let _ = sync_tx.send(requests);
-                                n_tokens = 0;
-                                requests = Vec::new();
-                                continue;
-                            }
-                        }
+                        if requests.len() + 1 == conf.max_batch_size {
+                            requests.push(request);
+                            let _ = sync_tx.send(requests);
+                            n_tokens = 0;
+                            requests = Vec::new();
+                            continue;
+                        }
                         if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
                             let _ = sync_tx.send(requests);
                             n_tokens = request.input_ids.len();
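
With max_batch_size now a plain usize (default 1 per the CLI change below), the flush logic in this loop reduces to two independent checks. A condensed restatement as hypothetical helpers (sketch only, names mirror the surrounding code):

// Flush when the incoming request completes a full batch...
fn reaches_batch_size(pending_requests: usize, max_batch_size: usize) -> bool {
    pending_requests + 1 == max_batch_size
}
// ...or when its prompt tokens would overflow the token budget.
fn overflows_token_budget(pending_tokens: usize, incoming_tokens: usize, max_batch_total_tokens: usize) -> bool {
    pending_tokens + incoming_tokens > max_batch_total_tokens
}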
@@ -378,6 +420,8 @@ impl LlamacppBackend {
                             true,
                         );
                     }
+                    let mut pos = request.input_ids.len();
+
                     // TODO: close this loop :)
                     // TODO: move up for perf ?
@@ -409,15 +453,12 @@ impl LlamacppBackend {
                             break;
                         },
                     };
-                    let next = sampler.sample(&llamacpp);
-                    n_tokens += llamacpp.batch.n_tokens as usize;
-                    n_new_tokens += llamacpp.batch.n_tokens as usize;
+                    let idx = llamacpp.batch.n_tokens as usize - 1;
+                    let (next, logprob) = sampler.sample(&mut llamacpp, idx);
+                    n_new_tokens += 1;
                     debug!("tokens: {n_tokens} new: {n_new_tokens}");

-                    let logits = unsafe {
-                        *bindings::llama_get_logits_ith(llamacpp.ctx, -1)
-                    };
                     let kv_cache_used_cells = unsafe {
                         bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
                     };
@@ -437,7 +478,7 @@ impl LlamacppBackend {
                     let token = Token {
                         id: next as _,
                         text: piece,
-                        logprob: logits as _,
+                        logprob: logprob,
                         special: special,
                     };
                     let finish: Option<FinishReason> = {
let finish: Option<FinishReason> = { let finish: Option<FinishReason> = {
@ -471,7 +512,9 @@ impl LlamacppBackend {
top_tokens: vec![], top_tokens: vec![],
})); }));
llamacpp.batch.n_tokens = 0; llamacpp.batch.n_tokens = 0;
llamacpp.batch_push(next, n_tokens as _, &[0], true); // llamacpp.batch_clear_logits();
llamacpp.batch_push(next, pos as _, &[0], true);
pos += 1;
} }
} }
} // TODO remove this } // TODO remove this
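
On the position fix: the old code passed the running n_tokens counter as the KV position of each newly generated token, which drifts as soon as batching makes the two diverge; the new code tracks a per-request pos that starts at the prompt length and advances by one per generated token. A small sketch of that invariant (hypothetical helper, not in the diff):

// Sketch: per-request position bookkeeping as introduced by this commit.
// The token after an N-token prompt is decoded at position N, the next at N + 1, ...
struct PosTracker { pos: usize }

impl PosTracker {
    fn after_prompt(prompt_len: usize) -> Self { Self { pos: prompt_len } }
    fn next(&mut self) -> usize { let p = self.pos; self.pos += 1; p }
}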


@@ -30,7 +30,7 @@ struct Args {
     /// Number of threads to use for inference.
     #[clap(default_value = "1", long, env)]
-    n_threads: isize,
+    n_threads: usize,

     #[clap(default_value = "true", long, env)]
     /// Whether to use memory mapping.
@@ -77,8 +77,8 @@ struct Args {
     // max_waiting_tokens: usize,

     /// Maximum number of requests per batch
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
+    #[clap(default_value = "1", long, env)]
+    max_batch_size: usize,

     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -145,13 +145,11 @@ async fn main() -> Result<(), RouterError> {
             "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
     }
-    if let Some(max_batch_size) = args.max_batch_size {
-        if max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
-            return Err(RouterError::ArgumentValidation(
-                "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
-            ));
-        }
-    }
+    if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
+        return Err(RouterError::ArgumentValidation(
+            "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
+        ));
+    }
     if args.max_batch_total_tokens > args.n_ctx {
         return Err(RouterError::ArgumentValidation(
             "`max_batch_total_tokens` must be <= `n_ctx`".to_string(),