Remove n_ctx
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
parent 051ff2d5ce
commit 09a745f1b8
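
With this change the llama.cpp backend no longer takes a separate `n_ctx` setting: the context is created with exactly `max_batch_total_tokens` slots, so there is no second knob to keep consistent with the batching limits, and the KV-cache-full check in the finish-reason logic becomes redundant. Below is a minimal sketch of the resulting parameter mapping; the `LlamacppConfig` field names come from the diff, while the `context_params` helper is purely illustrative and does not exist in the codebase.

    // Illustrative sketch only: how the llama.cpp context parameters are
    // derived after this commit. The real code sets these fields on
    // llamacpp::context_default_params() through the FFI bindings.
    struct LlamacppConfig {
        max_batch_total_tokens: usize,          // batch budget, now also used as n_ctx
        max_physical_batch_total_tokens: usize, // micro-batch size (n_ubatch)
        max_batch_size: usize,                  // parallel sequences (n_seq_max)
    }

    /// Returns (n_ctx, n_batch, n_ubatch, n_seq_max).
    fn context_params(conf: &LlamacppConfig) -> (u32, u32, u32, u32) {
        (
            conf.max_batch_total_tokens as u32,          // n_ctx: previously a separate conf.n_ctx
            conf.max_batch_total_tokens as u32,          // n_batch
            conf.max_physical_batch_total_tokens as u32, // n_ubatch
            conf.max_batch_size as u32,                  // n_seq_max
        )
    }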
@@ -130,7 +130,6 @@ impl LlamacppGGMLType {
 
 pub struct LlamacppConfig {
     pub model_gguf: String,
-    pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
     pub max_physical_batch_total_tokens: usize,
     pub max_batch_size: usize,
@@ -206,7 +205,6 @@ struct Llamacpp {
     vocab: *const llamacpp::llama_vocab,
     logprobs: Vec<llamacpp::llama_token_data>,
     batch: llamacpp::llama_batch,
-    n_ctx: u32,
 }
 
 extern "C" fn llamacpp_log_callback(
@@ -251,7 +249,7 @@ impl Llamacpp {
         }
         let ctx = unsafe {
             let mut params = llamacpp::context_default_params();
-            params.n_ctx = conf.n_ctx as _;
+            params.n_ctx = conf.max_batch_total_tokens as _;
             params.n_batch = conf.max_batch_total_tokens as _;
             params.n_ubatch = conf.max_physical_batch_total_tokens as _;
             params.n_seq_max = conf.max_batch_size as _;
@@ -268,8 +266,6 @@ impl Llamacpp {
         if ctx.is_null() {
             return Err(BackendError::Llamacpp("Failed to init context".to_string()))
         }
-        let n_ctx = unsafe { llamacpp::n_ctx(ctx) };
-
         let vocab = unsafe {
             llamacpp::model_get_vocab(model)
         };
@@ -291,7 +287,7 @@ impl Llamacpp {
         let batch = unsafe {
             llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
-        Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
+        Ok(Llamacpp{model, ctx, vocab, logprobs, batch})
     }
 
     fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
@@ -559,9 +555,6 @@ impl LlamacppBackend {
                     }
                     break;
                 }
-                let kv_cache_used_cells = unsafe {
-                    llamacpp::get_kv_cache_used_cells(llamacpp.ctx)
-                };
                 for seq in seqs.iter_mut() {
                     if !seq.running {
                         continue;
@@ -595,8 +588,6 @@ impl LlamacppBackend {
                        Some(FinishReason::EndOfSequenceToken)
                    } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
                        Some(FinishReason::Length)
-                   } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
-                       Some(FinishReason::Length) // TODO: check
                    } else {
                        None
                    }
@@ -24,10 +24,6 @@ struct Args {
     #[clap(long, env)]
     model_gguf: String, // TODO Option() with hf->gguf & quantize
 
-    /// Context size for the model.
-    #[clap(default_value = "4096", long, env)]
-    n_ctx: usize,
-
     /// Number of threads to use for generation.
     #[clap(long, env)]
     n_threads: Option<usize>,
@@ -198,11 +194,6 @@ async fn main() -> Result<(), RouterError> {
             "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
     }
-    if args.max_batch_total_tokens > args.n_ctx {
-        return Err(RouterError::ArgumentValidation(
-            "`max_batch_total_tokens` must be <= `n_ctx`".to_string(),
-        ));
-    }
 
     // TODO: check if we use the same cache of Server
     // check if llamacpp is faster
@@ -224,7 +215,6 @@ async fn main() -> Result<(), RouterError> {
     let (backend, ok, shutdown) = LlamacppBackend::new(
         LlamacppConfig {
             model_gguf: args.model_gguf,
-            n_ctx: args.n_ctx,
             n_threads: n_threads,
             n_threads_batch: n_threads_batch,
             n_gpu_layers: args.n_gpu_layers,
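
With `n_ctx` removed, the router-side validation also simplifies: the old `max_batch_total_tokens <= n_ctx` check holds by construction once the context itself is sized to `max_batch_total_tokens`, so only the batching constraint remains. The sketch below illustrates this; the exact condition is inferred from the error message in the diff, and the `Args`/`RouterError` stand-ins are simplified stubs, not the real clap and error types.

    // Simplified stand-ins so the sketch is self-contained; the real router
    // derives Args from clap and RouterError from its own error enum.
    struct Args {
        max_batch_size: usize,
        max_total_tokens: usize,
        max_batch_total_tokens: usize,
    }

    enum RouterError {
        ArgumentValidation(String),
    }

    fn validate(args: &Args) -> Result<(), RouterError> {
        if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(
                "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
            ));
        }
        // No `max_batch_total_tokens <= n_ctx` check anymore: the llama.cpp
        // context is created with exactly max_batch_total_tokens cells.
        Ok(())
    }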