Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Handle ctx args & fix sampling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

parent: a7b4b04cb5
commit: 8d2dfdf668
@@ -18,7 +18,7 @@ RUN apt-get install -y \
 # nvidia-cuda-toolkit
 # -DGGML_CUDA=ON \
 
-ENV LLAMA_VERSION=b4585
+ENV LLAMA_VERSION=b4599
 RUN git clone --depth 1 -b ${LLAMA_VERSION} https://github.com/ggerganov/llama.cpp \
 && cd llama.cpp \
 && cmake -B build \
@@ -25,9 +25,9 @@ pub struct LlamacppConfig {
 pub model_gguf: String,
 pub n_ctx: usize,
 pub max_batch_total_tokens: usize,
-pub max_batch_size: Option<usize>,
+pub max_batch_size: usize,
 pub batch_timeout: Duration,
-pub n_threads: isize,
+pub n_threads: usize,
 pub use_mmap: bool,
 pub use_mlock: bool,
 pub flash_attention: bool,
@@ -88,6 +88,7 @@ struct Llamacpp {
 model: *mut bindings::llama_model,
 ctx: *mut bindings::llama_context,
 vocab: *const bindings::llama_vocab,
+logprobs: Vec<bindings::llama_token_data>,
 batch: bindings::llama_batch,
 n_ctx: u32,
 }
@@ -125,8 +126,11 @@ impl Llamacpp {
 let ctx = unsafe {
 let mut params = bindings::llama_context_default_params();
 params.n_ctx = conf.n_ctx as _;
+params.n_batch = conf.max_batch_total_tokens as _;
+params.n_ubatch = conf.max_batch_total_tokens as _; // TODO ?
+params.n_seq_max = conf.max_batch_size as _;
 params.n_threads = conf.n_threads as _;
-params.n_threads_batch = conf.n_threads as _;
+params.n_threads_batch = conf.n_threads as _; // TODO ?
 params.flash_attn = conf.flash_attention;
 params.no_perf = true;
 bindings::llama_init_from_model(model, params)
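
This hunk is the "ctx args" half of the commit: the router's batching limits are now forwarded into the llama.cpp context parameters (n_batch being the logical token budget per decode call, n_ubatch the physical micro-batch, n_seq_max the number of concurrent sequences). A minimal sketch of that mapping, with `CtxArgs` and `ctx_args` as hypothetical stand-ins for the generated bindings, not the backend's real types:

    // Sketch only: `CtxArgs` stands in for bindings::llama_context_params.
    #[derive(Debug)]
    struct CtxArgs {
        n_ctx: u32,     // size of the KV cache, in tokens
        n_batch: u32,   // logical max tokens per decode call
        n_ubatch: u32,  // physical micro-batch size
        n_seq_max: u32, // max concurrent sequences
    }

    // Hypothetical helper mirroring how the config values are forwarded above.
    fn ctx_args(n_ctx: usize, max_batch_total_tokens: usize, max_batch_size: usize) -> CtxArgs {
        CtxArgs {
            n_ctx: n_ctx as u32,
            n_batch: max_batch_total_tokens as u32,
            n_ubatch: max_batch_total_tokens as u32, // same value for now, hence the TODO above
            n_seq_max: max_batch_size as u32,
        }
    }

    fn main() {
        // e.g. a 4096-token context with the whole budget in one batch, one sequence
        println!("{:?}", ctx_args(4096, 4096, 1));
    }
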
@@ -142,11 +146,30 @@ impl Llamacpp {
 if vocab.is_null() {
 return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
 }
+let n_tokens = unsafe {
+bindings::llama_vocab_n_tokens(vocab)
+};
+let mut logprobs = Vec::with_capacity(n_tokens as usize);
+
+for token in 0..n_tokens {
+logprobs.push(bindings::llama_token_data {
+id: token,
+logit: 0.0,
+p: 0.0,
+});
+}
 let batch = unsafe {
 bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
 };
-// TODO check batch
-Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
+Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
 }
 
+fn batch_clear_logits(&mut self) {
+for n in 0..self.batch.n_tokens as usize{
+unsafe {
+*self.batch.logits.add(n) = 0 as i8;
+}
+}
+}
+
 fn batch_push(
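
The constructor now pre-allocates one llama_token_data slot per vocabulary entry, so the sampler can refresh logits in place on every step instead of rebuilding a candidate list. A rough illustration of that reuse pattern with plain Rust types (`TokenData` and `Candidates` are stand-ins, not the binding structs):

    // Sketch of the pre-allocated candidates buffer; TokenData mirrors
    // bindings::llama_token_data { id, logit, p }.
    #[derive(Clone, Copy)]
    struct TokenData { id: i32, logit: f32, p: f32 }

    struct Candidates { buf: Vec<TokenData> }

    impl Candidates {
        // Allocate once, sized to the vocabulary.
        fn new(n_vocab: usize) -> Self {
            let buf = (0..n_vocab)
                .map(|id| TokenData { id: id as i32, logit: 0.0, p: 0.0 })
                .collect();
            Self { buf }
        }
        // Refresh in place from the latest logits; no per-step allocation.
        fn refresh(&mut self, logits: &[f32]) {
            for (slot, &logit) in self.buf.iter_mut().zip(logits) {
                slot.logit = logit;
                slot.p = 0.0;
            }
        }
    }

    fn main() {
        let mut c = Candidates::new(4);
        c.refresh(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(c.buf[3].logit, 0.4);
    }
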
@@ -156,6 +179,7 @@ impl Llamacpp {
 seq_ids: &[bindings::llama_seq_id],
 logits: bool,
 ) {
+debug!("push {token} {pos} {logits}");
 // TODO check evertyhing..
 let n = self.batch.n_tokens as usize;
 
@@ -279,9 +303,29 @@ impl LlamacppSampler {
 }
 }
 
-fn sample(&self, llamacpp: &Llamacpp) -> bindings::llama_token {
-// use apply/accept ?
-unsafe { bindings::llama_sampler_sample(self.chain, llamacpp.ctx, -1) }// -1 ?
+fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) {
+let logits = unsafe {
+bindings::llama_get_logits_ith(llamacpp.ctx, idx as _)
+};
+for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
+*logprob = bindings::llama_token_data {
+id: token as _,
+logit: unsafe { *logits.offset(token as _) },
+p: 0.0,
+};
+}
+let mut view = bindings::llama_token_data_array {
+data: llamacpp.logprobs.as_mut_ptr(),
+size: llamacpp.logprobs.len(),
+selected: -1,
+sorted: false,
+};
+unsafe {
+bindings::llama_sampler_apply(self.chain, &mut view);
+let logprob = *view.data.offset(view.selected as _);
+bindings::llama_sampler_accept(self.chain, logprob.id);
+(logprob.id, logprob.logit) // maybe p.ln() ?
+}
 }
 }
 
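
This is the "fix sampling" half: sample() now fills a llama_token_data_array over the pre-allocated buffer, lets the sampler chain pick an entry via llama_sampler_apply, and registers the choice with llama_sampler_accept. Note that the value returned as the logprob is still the raw logit (the `// maybe p.ln() ?` comment flags the open question). If a true log-probability were wanted, it would be the log-softmax of that logit; the sketch below is illustrative only and is not what the commit does:

    // Sketch: convert the selected token's raw logit into log P(token)
    // via a numerically stable log-softmax over all vocab logits.
    fn log_softmax(logits: &[f32], selected: usize) -> f32 {
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln() + max;
        logits[selected] - log_sum_exp
    }

    fn main() {
        let logits = [2.0_f32, 0.5, -1.0];
        // log-probability of the highest-logit token
        println!("{}", log_softmax(&logits, 0));
    }
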
@@ -321,15 +365,13 @@ impl LlamacppBackend {
 match timeout(conf.batch_timeout, rx.recv()).await {
 Ok(None) => break, // closed
 Ok(Some(request)) => {
-if let Some(max_batch_size) = conf.max_batch_size {
-if requests.len() + 1 == max_batch_size {
+if requests.len() + 1 == conf.max_batch_size {
 requests.push(request);
 let _ = sync_tx.send(requests);
 n_tokens = 0;
 requests = Vec::new();
 continue;
 }
-}
 if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
 let _ = sync_tx.send(requests);
 n_tokens = request.input_ids.len();
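
With max_batch_size no longer optional, the receive loop flushes a pending batch on two conditions: the request count reaching max_batch_size, or the token budget max_batch_total_tokens overflowing. A condensed sketch of those flush rules, using plain lengths instead of real requests and hypothetical names (the real loop also handles the batch timeout and channel close, and the hunk above is truncated):

    // Condensed sketch of the two flush rules in the batching loop.
    struct Batcher {
        requests: Vec<usize>,          // queued request input lengths
        n_tokens: usize,               // tokens already queued
        max_batch_size: usize,
        max_batch_total_tokens: usize,
    }

    impl Batcher {
        // Returns a batch to submit whenever a flush condition triggers.
        fn push(&mut self, input_len: usize) -> Option<Vec<usize>> {
            // Rule 1: this request fills the batch up to max_batch_size.
            if self.requests.len() + 1 == self.max_batch_size {
                self.requests.push(input_len);
                self.n_tokens = 0;
                return Some(std::mem::take(&mut self.requests));
            }
            // Rule 2: the token budget would overflow; flush what is queued
            // and start a new batch with this request.
            if self.n_tokens + input_len > self.max_batch_total_tokens {
                let batch = std::mem::take(&mut self.requests);
                self.n_tokens = input_len;
                self.requests.push(input_len);
                return Some(batch);
            }
            self.n_tokens += input_len;
            self.requests.push(input_len);
            None
        }
    }

    fn main() {
        let mut b = Batcher { requests: Vec::new(), n_tokens: 0, max_batch_size: 3, max_batch_total_tokens: 100 };
        assert!(b.push(60).is_none()); // queued: 60 tokens
        assert!(b.push(50).is_some()); // 60 + 50 > 100: flush the queued batch, keep this request
        assert!(b.push(40).is_none()); // fits in the fresh batch (50 + 40 <= 100)
    }
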
@@ -378,6 +420,8 @@ impl LlamacppBackend {
 true,
 );
 }
+let mut pos = request.input_ids.len();
+
 // TODO: close this loop :)
 
 // TODO: move up for perf ?
@@ -409,15 +453,12 @@ impl LlamacppBackend {
 break;
 },
 };
-let next = sampler.sample(&llamacpp);
-n_tokens += llamacpp.batch.n_tokens as usize;
-n_new_tokens += llamacpp.batch.n_tokens as usize;
+let idx = llamacpp.batch.n_tokens as usize - 1;
+let (next, logprob) = sampler.sample(&mut llamacpp, idx);
+n_new_tokens += 1;
 
 debug!("tokens: {n_tokens} new: {n_new_tokens}");
 
-let logits = unsafe {
-*bindings::llama_get_logits_ith(llamacpp.ctx, -1)
-};
 let kv_cache_used_cells = unsafe {
 bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
 };
@@ -437,7 +478,7 @@ impl LlamacppBackend {
 let token = Token {
 id: next as _,
 text: piece,
-logprob: logits as _,
+logprob: logprob,
 special: special,
 };
 let finish: Option<FinishReason> = {
@@ -471,7 +512,9 @@ impl LlamacppBackend {
 top_tokens: vec![],
 }));
 llamacpp.batch.n_tokens = 0;
-llamacpp.batch_push(next, n_tokens as _, &[0], true);
+// llamacpp.batch_clear_logits();
+llamacpp.batch_push(next, pos as _, &[0], true);
+pos += 1;
 }
 }
 } // TODO remove this
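
The decode loop now tracks the KV position explicitly: pos starts at the prompt length, each step clears the batch, pushes the freshly sampled token at pos, and advances pos by one. A small sketch of that bookkeeping, where `Batch`, `clear` and `push` are hypothetical stand-ins for the llama.cpp batch and the batch_push above:

    // Sketch of per-step position bookkeeping for one sequence.
    struct Batch { tokens: Vec<(i32, usize)> } // (token id, position)

    impl Batch {
        fn clear(&mut self) { self.tokens.clear(); }                      // llamacpp.batch.n_tokens = 0;
        fn push(&mut self, token: i32, pos: usize) { self.tokens.push((token, pos)); }
    }

    fn decode_step(batch: &mut Batch, next: i32, pos: &mut usize) {
        batch.clear();
        batch.push(next, *pos); // llamacpp.batch_push(next, pos as _, &[0], true);
        *pos += 1;              // pos += 1;
    }

    fn main() {
        let mut batch = Batch { tokens: Vec::new() };
        let mut pos = 5; // prompt length
        decode_step(&mut batch, 42, &mut pos);
        assert_eq!(batch.tokens, vec![(42, 5)]);
        assert_eq!(pos, 6);
    }
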
@@ -30,7 +30,7 @@ struct Args {
 
 /// Number of threads to use for inference.
 #[clap(default_value = "1", long, env)]
-n_threads: isize,
+n_threads: usize,
 
 #[clap(default_value = "true", long, env)]
 /// Whether to use memory mapping.
@@ -77,8 +77,8 @@ struct Args {
 // max_waiting_tokens: usize,
 
 /// Maximum number of requests per batch
-#[clap(long, env)]
-max_batch_size: Option<usize>,
+#[clap(default_value = "1", long, env)]
+max_batch_size: usize,
 
 /// The IP address to listen on
 #[clap(default_value = "0.0.0.0", long, env)]
@@ -145,13 +145,11 @@ async fn main() -> Result<(), RouterError> {
 "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
 ));
 }
-if let Some(max_batch_size) = args.max_batch_size {
-if max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
+if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
 return Err(RouterError::ArgumentValidation(
 "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
 ));
 }
-}
 if args.max_batch_total_tokens > args.n_ctx {
 return Err(RouterError::ArgumentValidation(
 "`max_batch_total_tokens` must be <= `n_ctx`".to_string(),
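
After the change, startup validation enforces three relations directly on the flags: max_total_tokens <= max_batch_total_tokens, max_batch_size * max_total_tokens <= max_batch_total_tokens, and max_batch_total_tokens <= n_ctx. A standalone sketch of those checks, with a hypothetical helper and a simplified error type rather than RouterError:

    // Sketch of the three argument checks enforced at startup.
    fn validate(max_total_tokens: usize, max_batch_total_tokens: usize,
                max_batch_size: usize, n_ctx: usize) -> Result<(), String> {
        if max_total_tokens > max_batch_total_tokens {
            return Err("`max_total_tokens` must be <= `max_batch_total_tokens`".into());
        }
        if max_batch_size * max_total_tokens > max_batch_total_tokens {
            return Err("`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".into());
        }
        if max_batch_total_tokens > n_ctx {
            return Err("`max_batch_total_tokens` must be <= `n_ctx`".into());
        }
        Ok(())
    }

    fn main() {
        assert!(validate(512, 4096, 4, 4096).is_ok());
        assert!(validate(512, 4096, 16, 4096).is_err()); // 16 * 512 > 4096
    }
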