mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-10-08 22:45:23 +00:00
Add llamacpp backend
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
88fd56f549
commit
95e221eece
926
Cargo.lock
generated
926
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -5,6 +5,7 @@ members = [
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
"backends/trtllm",
|
||||
"backends/llamacpp",
|
||||
"launcher",
|
||||
"router"
|
||||
]
|
||||
|
77
Dockerfile_llamacpp
Normal file
77
Dockerfile_llamacpp
Normal file
@ -0,0 +1,77 @@
|
||||
FROM ubuntu:24.04 AS base
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3-venv \
|
||||
python3-pip
|
||||
|
||||
RUN python3 -m venv /venv
|
||||
ENV PATH="/venv/bin:$PATH"
|
||||
RUN pip3 install --no-cache-dir transformers
|
||||
|
||||
FROM base AS deps
|
||||
WORKDIR /deps
|
||||
|
||||
RUN apt-get install -y \
|
||||
clang cmake git
|
||||
|
||||
# nvidia-cuda-toolkit
|
||||
# -DGGML_CUDA=ON \
|
||||
|
||||
ENV LLAMA_VERSION=b4585
|
||||
RUN git clone --depth 1 -b ${LLAMA_VERSION} https://github.com/ggerganov/llama.cpp \
|
||||
&& cd llama.cpp \
|
||||
&& cmake -B build \
|
||||
-DCMAKE_INSTALL_PREFIX=/usr \
|
||||
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
|
||||
-DCMAKE_C_COMPILER=clang \
|
||||
-DCMAKE_CXX_COMPILER=clang++ \
|
||||
-DLLAMA_BUILD_COMMON=OFF \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
&& cmake --build build --config Release -j \
|
||||
&& cmake --install build
|
||||
|
||||
# ENV MIMALLOC_VERSION=v3.0.1
|
||||
# RUN git clone --depth 1 -b ${MIMALLOC_VERSION} https://github.com/microsoft/mimalloc \
|
||||
# && cd mimalloc \
|
||||
# && cmake -B build \
|
||||
# -DCMAKE_INSTALL_PREFIX=/usr \
|
||||
# -DCMAKE_INSTALL_LIBDIR=/usr/lib \
|
||||
# -DCMAKE_C_COMPILER=clang \
|
||||
# -DCMAKE_CXX_COMPILER=clang++ \
|
||||
# && cmake --build build --config Release -j \
|
||||
# && cmake --install build
|
||||
|
||||
RUN apt-get install -y \
|
||||
curl pkg-config libssl-dev
|
||||
|
||||
WORKDIR /app
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
RUN curl -sSf https://sh.rustup.rs | sh -s -- -y --no-modify-path --default-toolchain none
|
||||
ENV PATH="/root/.cargo/bin:$PATH"
|
||||
RUN cargo install cargo-chef --locked
|
||||
|
||||
FROM deps AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM deps AS builder
|
||||
COPY --from=planner /app/recipe.json recipe.json
|
||||
RUN cargo chef cook \
|
||||
--recipe-path recipe.json \
|
||||
--profile release-opt \
|
||||
--package text-generation-router-llamacpp
|
||||
COPY . .
|
||||
RUN cargo build \
|
||||
--profile release-opt \
|
||||
--package text-generation-router-llamacpp --frozen
|
||||
|
||||
FROM base AS runtime
|
||||
|
||||
COPY --from=deps /usr/lib/libllama.so /usr/lib/
|
||||
COPY --from=deps /usr/lib/libggml*.so /usr/lib/
|
||||
COPY --from=builder /app/target/release-opt/text-generation-router-llamacpp /bin/text-generation-launcher
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
2
backends/llamacpp/.cargo/config.toml
Normal file
2
backends/llamacpp/.cargo/config.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[build]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
20
backends/llamacpp/Cargo.toml
Normal file
20
backends/llamacpp/Cargo.toml
Normal file
@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "text-generation-router-llamacpp"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.71.1"
|
||||
pkg-config = "0.3.31"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.85"
|
||||
clap = "4.5.27"
|
||||
text-generation-router = { path = "../../router" }
|
||||
thiserror = "2.0.11"
|
||||
tokenizers.workspace = true
|
||||
tokio = "1.43.0"
|
||||
tokio-stream = "0.1.17"
|
||||
tracing = "0.1.41"
|
20
backends/llamacpp/build.rs
Normal file
20
backends/llamacpp/build.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn main() {
|
||||
let bindings = bindgen::Builder::default()
|
||||
.header("src/wrapper.h")
|
||||
.prepend_enum_name(false)
|
||||
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
|
||||
.generate()
|
||||
.expect("Unable to generate bindings");
|
||||
|
||||
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||
bindings
|
||||
.write_to_file(out_path.join("bindings.rs"))
|
||||
.expect("Couldn't write bindings!");
|
||||
|
||||
pkg_config::Config::new().probe("llama").unwrap();
|
||||
|
||||
println!("cargo::rerun-if-changed=build.rs");
|
||||
}
|
434
backends/llamacpp/src/backend.rs
Normal file
434
backends/llamacpp/src/backend.rs
Normal file
@ -0,0 +1,434 @@
|
||||
mod bindings {
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(non_snake_case)]
|
||||
#![allow(dead_code)]
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
}
|
||||
use async_trait::async_trait;
|
||||
use std::ffi::CString;
|
||||
use std::sync::Once;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::{ValidGenerateRequest};
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::sync::{watch, oneshot};
|
||||
use tokio::task::spawn_blocking;
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, info, warn, error, trace};
|
||||
use tracing::{instrument};
|
||||
|
||||
pub struct LlamacppConfig {
|
||||
pub model_gguf: String,
|
||||
pub n_ctx: u32,
|
||||
pub n_threads: i32,
|
||||
pub use_mmap: bool,
|
||||
pub use_mlock: bool,
|
||||
pub flash_attention: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LlamacppRequest {
|
||||
input_ids: Vec<i32>,
|
||||
top_k: i32,
|
||||
top_p: f32,
|
||||
typical_p: f32,
|
||||
min_keep: usize,
|
||||
temp: f32,
|
||||
seed: u32,
|
||||
penalty_last_n: i32,
|
||||
penalty_repeat: f32,
|
||||
penalty_freq: f32,
|
||||
penalty_present: f32,
|
||||
max_new_tokens: usize,
|
||||
tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||
time: Instant,
|
||||
}
|
||||
|
||||
pub struct LlamacppBackend {
|
||||
tx: UnboundedSender<LlamacppRequest>,
|
||||
status: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl LlamacppRequest {
|
||||
fn new(
|
||||
from: &ValidGenerateRequest,
|
||||
tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||
) -> Option<Self>{
|
||||
if let Some(input_ids) = from.input_ids.as_ref() {
|
||||
Some(LlamacppRequest {
|
||||
input_ids: input_ids.iter().map(|&x| x as i32).collect(),
|
||||
top_k: from.parameters.top_k as _,
|
||||
top_p: from.parameters.top_p as _,
|
||||
typical_p: from.parameters.typical_p as _,
|
||||
min_keep: 0, // disabled
|
||||
temp: from.parameters.temperature as _,
|
||||
seed: from.parameters.seed as _,
|
||||
penalty_last_n: -1, // 0 = disabled, -1 = context size
|
||||
penalty_repeat: from.parameters.repetition_penalty as _,
|
||||
penalty_freq: from.parameters.frequency_penalty as _,
|
||||
penalty_present: 0.0, // disabled
|
||||
max_new_tokens: from.stopping_parameters.max_new_tokens as _,
|
||||
tx: tx,
|
||||
time: Instant::now(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Llamacpp {
|
||||
model: *mut bindings::llama_model,
|
||||
ctx: *mut bindings::llama_context,
|
||||
vocab: *const bindings::llama_vocab,
|
||||
n_ctx: u32,
|
||||
}
|
||||
|
||||
extern "C" fn llamacpp_log_callback(
|
||||
level: bindings::ggml_log_level,
|
||||
msg: *const std::os::raw::c_char,
|
||||
_user_data: *mut std::os::raw::c_void,
|
||||
) {
|
||||
let cmsg = unsafe { std::ffi::CStr::from_ptr(msg) };
|
||||
let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();
|
||||
|
||||
match level {
|
||||
bindings::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
|
||||
bindings::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
|
||||
bindings::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
|
||||
bindings::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
|
||||
_ => trace!(target: "llamacpp", "{}", rmsg),
|
||||
}
|
||||
}
|
||||
|
||||
impl Llamacpp {
|
||||
fn new(conf: LlamacppConfig) -> Result<Self, BackendError> {
|
||||
let gguf = CString::new(conf.model_gguf)?;
|
||||
|
||||
let model = unsafe {
|
||||
let mut params = bindings::llama_model_default_params();
|
||||
params.use_mmap = conf.use_mmap;
|
||||
params.use_mlock = conf.use_mlock;
|
||||
bindings::llama_model_load_from_file(gguf.as_ptr(), params)
|
||||
};
|
||||
if model.is_null() {
|
||||
return Err(BackendError::Llamacpp("Failed to load model".to_string()))
|
||||
}
|
||||
let ctx = unsafe {
|
||||
let mut params = bindings::llama_context_default_params();
|
||||
params.n_ctx = conf.n_ctx;
|
||||
params.n_threads = conf.n_threads;
|
||||
params.n_threads_batch = conf.n_threads;
|
||||
params.flash_attn = conf.flash_attention;
|
||||
params.no_perf = true;
|
||||
bindings::llama_init_from_model(model, params)
|
||||
};
|
||||
if ctx.is_null() {
|
||||
return Err(BackendError::Llamacpp("Failed to init context".to_string()))
|
||||
}
|
||||
let n_ctx = unsafe { bindings::llama_n_ctx(ctx) };
|
||||
|
||||
let vocab = unsafe {
|
||||
bindings::llama_model_get_vocab(model)
|
||||
};
|
||||
if vocab.is_null() {
|
||||
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
|
||||
}
|
||||
Ok(Llamacpp{model, ctx, vocab, n_ctx})
|
||||
}
|
||||
// useless ?
|
||||
fn warmup(&self) {
|
||||
let mut buf: Vec<bindings::llama_token> = Vec::new();
|
||||
|
||||
let bos = unsafe {
|
||||
bindings::llama_vocab_bos(self.vocab)
|
||||
};
|
||||
if bos != bindings::LLAMA_TOKEN_NULL {
|
||||
buf.push(bos);
|
||||
}
|
||||
let eos = unsafe {
|
||||
bindings::llama_vocab_eos(self.vocab)
|
||||
};
|
||||
if eos != bindings::LLAMA_TOKEN_NULL {
|
||||
buf.push(eos);
|
||||
}
|
||||
if buf.is_empty() {
|
||||
warn!("Warmup failed: no bos/eos...");
|
||||
return;
|
||||
}
|
||||
let batch = unsafe {
|
||||
bindings::llama_batch_get_one(buf.as_ptr() as _, buf.len() as _)
|
||||
};
|
||||
if unsafe { bindings::llama_decode(self.ctx, batch) } != 0 {
|
||||
error!("Warmup failed: llama_decode() returned an error");
|
||||
}
|
||||
unsafe {
|
||||
bindings::llama_kv_cache_clear(self.ctx);
|
||||
bindings::llama_synchronize(self.ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Llamacpp {
|
||||
fn drop(&mut self) {
|
||||
if !self.ctx.is_null() {
|
||||
unsafe { bindings::llama_free(self.ctx) };
|
||||
}
|
||||
if !self.model.is_null() {
|
||||
unsafe { bindings::llama_model_free(self.model) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LlamacppSampler {
|
||||
chain: *mut bindings::llama_sampler,
|
||||
}
|
||||
|
||||
impl LlamacppSampler {
|
||||
fn new(req: &LlamacppRequest) -> Option<Self> {
|
||||
let chain = unsafe {
|
||||
let params = bindings::llama_sampler_chain_default_params();
|
||||
bindings::llama_sampler_chain_init(params)
|
||||
};
|
||||
if chain.is_null() {
|
||||
error!("Failed to init sampler");
|
||||
return None;
|
||||
}
|
||||
let top_k = unsafe {
|
||||
bindings::llama_sampler_init_top_k(req.top_k)
|
||||
};
|
||||
let top_p = unsafe {
|
||||
bindings::llama_sampler_init_top_p(req.top_p, req.min_keep)
|
||||
};
|
||||
let typical_p = unsafe {
|
||||
bindings::llama_sampler_init_typical(req.typical_p, req.min_keep)
|
||||
};
|
||||
let temp = unsafe {
|
||||
bindings::llama_sampler_init_temp(req.temp)
|
||||
};
|
||||
let penalties = unsafe {
|
||||
bindings::llama_sampler_init_penalties(
|
||||
req.penalty_last_n,
|
||||
req.penalty_repeat,
|
||||
req.penalty_freq,
|
||||
req.penalty_present,
|
||||
)
|
||||
};
|
||||
let dist = unsafe {
|
||||
bindings::llama_sampler_init_dist(req.seed)
|
||||
};
|
||||
let mut failed = false;
|
||||
|
||||
for (k, v) in &[("top_k", top_k),
|
||||
("top_p", top_p),
|
||||
("typical_p", typical_p),
|
||||
("temp", temp),
|
||||
("penalties", penalties),
|
||||
("dist", dist)] {
|
||||
if v.is_null() {
|
||||
error!("Failed to init {k} sampler");
|
||||
failed = true;
|
||||
} else {
|
||||
unsafe { bindings::llama_sampler_chain_add(chain, *v) };
|
||||
}
|
||||
}
|
||||
if failed {
|
||||
None
|
||||
} else {
|
||||
Some(LlamacppSampler{chain})
|
||||
}
|
||||
}
|
||||
|
||||
fn sample(&self, llamacpp: &Llamacpp) -> bindings::llama_token {
|
||||
// use apply/accept ?
|
||||
unsafe { bindings::llama_sampler_sample(self.chain, llamacpp.ctx, -1) }// -1 ?
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LlamacppSampler {
|
||||
fn drop(&mut self) {
|
||||
if !self.chain.is_null() {
|
||||
unsafe { bindings::llama_sampler_free(self.chain) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
|
||||
impl LlamacppBackend {
|
||||
pub fn new(
|
||||
conf: LlamacppConfig,
|
||||
tokenizer: Tokenizer,
|
||||
) -> (Self, oneshot::Receiver<Result<(),BackendError>>) {
|
||||
|
||||
// Setup llama & export logs, once and for all
|
||||
INIT.call_once(|| unsafe {
|
||||
bindings::llama_log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
|
||||
bindings::llama_backend_init();
|
||||
bindings::llama_numa_init(bindings::GGML_NUMA_STRATEGY_NUMACTL); // TODO add option & test
|
||||
});
|
||||
|
||||
let (status_tx, status_rx) = watch::channel(false);
|
||||
let (ok_tx, ok_rx) = oneshot::channel();
|
||||
let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
|
||||
|
||||
spawn_blocking(move || {
|
||||
let llamacpp = match Llamacpp::new(conf) {
|
||||
Ok(v) => { let _ = ok_tx.send(Ok(())); v },
|
||||
Err(e) => { let _ = ok_tx.send(Err(e)); return; },
|
||||
};
|
||||
llamacpp.warmup();
|
||||
|
||||
let vocab = tokenizer.get_added_vocabulary();
|
||||
|
||||
// health() returns true
|
||||
let _ = status_tx.send(true);
|
||||
|
||||
while let Some(request) = rx.blocking_recv() {
|
||||
debug!("Request: {:?}", request);
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
// TODO: do a real batch
|
||||
let mut batch = unsafe {
|
||||
bindings::llama_batch_get_one(
|
||||
request.input_ids.as_ptr() as _,
|
||||
request.input_ids.len() as _,
|
||||
)
|
||||
};
|
||||
// TODO: move up for perf ?
|
||||
let sampler = match LlamacppSampler::new(&request) {
|
||||
Some(sampler) => sampler,
|
||||
_ => {
|
||||
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
|
||||
continue;
|
||||
},
|
||||
};
|
||||
let mut text = String::with_capacity(1024);
|
||||
let mut n_tokens: usize = 0;
|
||||
|
||||
loop {
|
||||
debug!(?batch);
|
||||
match unsafe { bindings::llama_decode(llamacpp.ctx, batch) } {
|
||||
0 => { },
|
||||
1 => {
|
||||
unsafe {
|
||||
// TODO: seq_rm & seq_add if model is compatible
|
||||
bindings::llama_kv_cache_clear(llamacpp.ctx);
|
||||
}
|
||||
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
|
||||
continue;
|
||||
},
|
||||
_ => {
|
||||
debug!("decode return <0");
|
||||
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
|
||||
break;
|
||||
},
|
||||
};
|
||||
let mut next = sampler.sample(&llamacpp);
|
||||
n_tokens += 1;
|
||||
debug!(?n_tokens);
|
||||
|
||||
let logits = unsafe {
|
||||
*bindings::llama_get_logits_ith(llamacpp.ctx, -1)
|
||||
};
|
||||
let kv_cache_used_cells = unsafe {
|
||||
bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
|
||||
};
|
||||
let piece = match tokenizer.decode(&[next as u32], false) {
|
||||
Ok(piece) => piece,
|
||||
Err(e) => {
|
||||
error!("Failed to decode token: {e}");
|
||||
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
|
||||
break;
|
||||
},
|
||||
};
|
||||
let special = vocab.is_special_token(&piece);
|
||||
|
||||
if !special {
|
||||
text.push_str(&piece);
|
||||
}
|
||||
let token = Token {
|
||||
id: next as _,
|
||||
text: piece,
|
||||
logprob: logits as _,
|
||||
special: special,
|
||||
};
|
||||
let finish: Option<FinishReason> = {
|
||||
if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
|
||||
Some(FinishReason::EndOfSequenceToken)
|
||||
} else if n_tokens == request.max_new_tokens {
|
||||
Some(FinishReason::Length)
|
||||
} else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
|
||||
Some(FinishReason::Length) // TODO: check
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
if let Some(reason) = finish {
|
||||
let _ = request.tx.send(Ok(InferStreamResponse::End {
|
||||
token: token,
|
||||
top_tokens: vec![],
|
||||
generated_text: GeneratedText {
|
||||
text: text,
|
||||
generated_tokens: n_tokens as _,
|
||||
finish_reason: reason,
|
||||
seed: Some(request.seed as _),
|
||||
},
|
||||
start: start_time,
|
||||
queued: request.time,
|
||||
}));
|
||||
break;
|
||||
}
|
||||
let _ = request.tx.send(Ok(InferStreamResponse::Intermediate {
|
||||
token: token,
|
||||
top_tokens: vec![],
|
||||
}));
|
||||
batch = unsafe {
|
||||
bindings::llama_batch_get_one(&mut next, 1)
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
(Self{tx, status: status_rx}, ok_rx)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for LlamacppBackend {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
debug!(?request);
|
||||
let (tx, rx) = unbounded_channel::<Result<InferStreamResponse, InferError>>();
|
||||
match LlamacppRequest::new(&request, tx) {
|
||||
Some(v) => match self.tx.send(v) {
|
||||
Err(e) => Err(InferError::GenerationError(e.to_string())),
|
||||
_ => Ok(UnboundedReceiverStream::new(rx)),
|
||||
},
|
||||
_ => Err(InferError::GenerationError("Bad request".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn health(&self, _: bool) -> bool {
|
||||
*self.status.borrow()
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"llamacpp"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum BackendError {
|
||||
#[error("CString error: {0}")]
|
||||
CStringError(#[from] std::ffi::NulError),
|
||||
#[error("Llamacpp error: {0}")]
|
||||
Llamacpp(String),
|
||||
}
|
210
backends/llamacpp/src/main.rs
Normal file
210
backends/llamacpp/src/main.rs
Normal file
@ -0,0 +1,210 @@
|
||||
mod backend;
|
||||
|
||||
use backend::{LlamacppConfig, LlamacppBackend, BackendError};
|
||||
use clap::{Parser};
|
||||
use text_generation_router::{logging, server, usage_stats};
|
||||
use thiserror::Error;
|
||||
use tokenizers::{Tokenizer, FromPretrainedParameters};
|
||||
use tokio::sync::oneshot::error::RecvError;
|
||||
use tracing::error;
|
||||
|
||||
/// Backend Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Name of the model to load.
|
||||
#[clap(long, env)]
|
||||
model_id: String,
|
||||
|
||||
/// Revision of the model.
|
||||
#[clap(default_value = "main", long, env)]
|
||||
revision: String,
|
||||
|
||||
/// Path to the GGUF model file to be used for inference.
|
||||
#[clap(long, env)]
|
||||
model_gguf: String, // TODO Option() with hf->gguf & quantize
|
||||
|
||||
/// Context size for the model.
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
n_ctx: u32,
|
||||
|
||||
/// Number of threads to use for inference.
|
||||
#[clap(default_value = "1", long, env)]
|
||||
n_threads: i32,
|
||||
|
||||
#[clap(default_value = "true", long, env)]
|
||||
/// Whether to use memory mapping.
|
||||
use_mmap: bool,
|
||||
|
||||
#[clap(default_value = "false", long, env)]
|
||||
/// Whether to use memory locking.
|
||||
use_mlock: bool,
|
||||
|
||||
/// Enable flash attention for faster inference. (EXPERIMENTAL)
|
||||
#[clap(default_value = "false", long, env)]
|
||||
flash_attention: bool,
|
||||
|
||||
/// TODO
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
|
||||
/// Maximum number of input tokens allowed per request.
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
|
||||
/// Maximum total tokens (input + output) allowed per request.
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
|
||||
// #[clap(default_value = "1.2", long, env)]
|
||||
// waiting_served_ratio: f32,
|
||||
// #[clap(default_value = "4096", long, env)]
|
||||
// max_batch_prefill_tokens: u32,
|
||||
// #[clap(long, env)]
|
||||
// max_batch_total_tokens: Option<u32>,
|
||||
// #[clap(default_value = "20", long, env)]
|
||||
// max_waiting_tokens: usize,
|
||||
// #[clap(long, env)]
|
||||
// max_batch_size: Option<usize>,
|
||||
|
||||
/// The IP address to listen on
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
|
||||
/// The port to listen on.
|
||||
#[clap(default_value = "3001", long, short, env)]
|
||||
port: u16,
|
||||
|
||||
// #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
// master_shard_uds_path: String,
|
||||
// #[clap(long, env)]
|
||||
// tokenizer_name: String,
|
||||
// #[clap(long, env)]
|
||||
// tokenizer_config_path: Option<String>,
|
||||
// #[clap(long, env, value_enum)]
|
||||
// trust_remote_code: bool,
|
||||
// #[clap(long, env)]
|
||||
// api_key: Option<String>,
|
||||
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
#[clap(default_value = "2000000", long, env)]
|
||||
payload_limit: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
let args = Args::parse();
|
||||
|
||||
logging::init_logging(
|
||||
args.otlp_endpoint,
|
||||
args.otlp_service_name,
|
||||
args.json_output
|
||||
);
|
||||
|
||||
if args.max_input_tokens >= args.max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: check if we use the same cache of Server
|
||||
// check if llamacpp is faster
|
||||
let tokenizer = {
|
||||
let token = std::env::var("HF_TOKEN")
|
||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||
.ok();
|
||||
let params = FromPretrainedParameters {
|
||||
revision: args.revision.clone(),
|
||||
token: token,
|
||||
..Default::default()
|
||||
};
|
||||
Tokenizer::from_pretrained(
|
||||
args.model_id.clone(),
|
||||
Some(params)
|
||||
)?
|
||||
};
|
||||
|
||||
let (backend, ok) = LlamacppBackend::new(
|
||||
LlamacppConfig {
|
||||
model_gguf: args.model_gguf,
|
||||
n_ctx: args.n_ctx,
|
||||
n_threads: args.n_threads,
|
||||
use_mmap: args.use_mmap,
|
||||
use_mlock: args.use_mlock,
|
||||
flash_attention: args.flash_attention,
|
||||
},
|
||||
tokenizer,
|
||||
);
|
||||
ok.await??;
|
||||
|
||||
server::run(
|
||||
backend,
|
||||
args.max_concurrent_requests,
|
||||
args.max_best_of,
|
||||
args.max_stop_sequences,
|
||||
args.max_top_n_tokens,
|
||||
args.max_input_tokens,
|
||||
args.max_total_tokens,
|
||||
args.validation_workers,
|
||||
None, // api_key
|
||||
args.model_id, // tokenizer_name
|
||||
args.tokenizer_config_path,
|
||||
Some(args.revision),
|
||||
false, // trust_remote_code
|
||||
args.hostname,
|
||||
args.port,
|
||||
args.cors_allow_origin,
|
||||
args.ngrok,
|
||||
args.ngrok_authtoken,
|
||||
args.ngrok_edge,
|
||||
args.disable_grammar_support,
|
||||
args.max_client_batch_size,
|
||||
args.usage_stats,
|
||||
args.payload_limit,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Tokenizer error: {0}")]
|
||||
Tokenizer(#[from] tokenizers::Error),
|
||||
#[error("Backend error: {0}")]
|
||||
Backend(#[from] BackendError),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Recv error: {0}")]
|
||||
RecvError(#[from] RecvError),
|
||||
}
|
1
backends/llamacpp/src/wrapper.h
Normal file
1
backends/llamacpp/src/wrapper.h
Normal file
@ -0,0 +1 @@
|
||||
#include <llama.h>
|
Loading…
Reference in New Issue
Block a user