Update the llamacpp backend (#3022)

* Build faster

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Make --model-gguf optional

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Enable mmap, offload_kqv & flash_attention by default

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Better error message

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update installed packages

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Save gguf in models/MODEL_ID/model.gguf

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix build with Mach-O

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Quantize without llama-quantize

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp and switch to ggml-org

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove make-gguf.sh

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update Cargo.lock

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Support HF_HUB_USER_AGENT_ORIGIN

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --build-arg llamacpp_native & llamacpp_cpu_arm_arch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët authored on 2025-03-11 09:19:01 +01:00, committed by GitHub
parent dc5f05f8e6
commit 094975c3a8
10 changed files with 202 additions and 64 deletions
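Taken together, these changes let the container be started without a pre-built GGUF (one is converted and quantized on first start) and with mmap, KQV offloading, and flash attention enabled by default. A minimal sketch of the new invocation, mirroring the updated documentation further down (port mapping and model ID are illustrative):

```bash
docker run \
    -p 3000:3000 \
    -e "HF_TOKEN=$HF_TOKEN" \
    -v "$HOME/models:/app/models" \
    tgi-llamacpp \
    --model-id "Qwen/Qwen2.5-3B-Instruct"
```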

Cargo.lock (generated)

@@ -4754,6 +4754,7 @@ dependencies = [
"async-trait",
"bindgen 0.71.1",
"clap 4.5.30",
"hf-hub 0.3.2",
"num_cpus",
"pkg-config",
"text-generation-router",

Dockerfile_llamacpp

@@ -1,13 +1,15 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
ARG llamacpp_version=b4651
ARG llamacpp_version=b4827
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
WORKDIR /opt/src
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt install -y \
RUN apt update && apt upgrade -y && apt install -y \
clang \
cmake \
curl \
@@ -17,9 +19,10 @@ RUN apt update && apt install -y \
pkg-config \
tar
ADD https://github.com/ggerganov/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN tar -xzf ${llamacpp_version}.tar.gz \
&& cd llama.cpp-${llamacpp_version} \
ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN mkdir -p llama.cpp \
&& tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
&& cd llama.cpp \
&& cmake -B build \
-DCMAKE_INSTALL_PREFIX=/usr \
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
@@ -27,6 +30,8 @@ RUN tar -xzf ${llamacpp_version}.tar.gz \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
-DGGML_CUDA=${llamacpp_cuda} \
-DGGML_NATIVE=${llamacpp_native} \
-DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
@@ -48,16 +53,18 @@ FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
--recipe-path recipe.json \
--profile release-opt \
--profile release \
--package text-generation-router-llamacpp
COPY . .
RUN cargo build \
--profile release-opt \
--profile release \
--package text-generation-router-llamacpp --frozen
FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app
RUN apt update && apt install -y \
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
python3-venv \
python3-pip
@@ -65,11 +72,16 @@ RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"
COPY backends/llamacpp/requirements.txt requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
RUN pip3 install --no-cache-dir \
-r requirements.txt \
-e gguf-py
COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release-opt/text-generation-router-llamacpp /usr/bin/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
ENV HF_HUB_ENABLE_HF_TRANSFER=1
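
The new `llamacpp_native`, `llamacpp_cpu_arm_arch`, `llamacpp_cuda`, and `cuda_arch` arguments above are forwarded directly to CMake. As a hedged illustration, a CUDA-enabled image build pinned to a single GPU architecture might look like this (the `86-real` value is an assumption for the example, not a recommendation):

```bash
# Illustrative build: enables the CUDA backend and narrows the CUDA
# architecture list; llamacpp_native stays at its default (ON).
docker build \
    -t tgi-llamacpp \
    --build-arg llamacpp_cuda=ON \
    --build-arg cuda_arch=86-real \
    https://github.com/huggingface/text-generation-inference.git \
    -f Dockerfile_llamacpp
```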

backends/llamacpp/Cargo.toml

@@ -12,10 +12,11 @@ pkg-config = "0.3.31"
[dependencies]
async-trait = "0.1.85"
clap = "4.5.27"
hf-hub.workspace = true
num_cpus = "1.16.0"
text-generation-router = { path = "../../router" }
thiserror = "2.0.11"
tokenizers.workspace = true
tokio = "1.43.0"
tokio = { version = "1.43.0", features = ["process"] }
tokio-stream = "0.1.17"
tracing = "0.1.41"

backends/llamacpp/build.rs

@@ -25,8 +25,9 @@ fn main() {
for path in &llama.link_paths {
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
}
if cfg!(target_os = "linux") {
println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
}
let bindings = bindgen::Builder::default()
.clang_args(
llama

backends/llamacpp/requirements.txt

@@ -1,3 +1,4 @@
transformers==4.49
huggingface-hub==0.28.1
hf-transfer==0.1.9
torch==2.6.0

backends/llamacpp/src/backend.rs

@@ -1,10 +1,5 @@
mod llamacpp {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
}
use crate::llamacpp;
use async_trait::async_trait;
use std::ffi::CString;
use std::mem::replace;

backends/llamacpp/src/llamacpp.rs

@@ -0,0 +1,5 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));

backends/llamacpp/src/main.rs

@@ -1,13 +1,21 @@
mod backend;
mod llamacpp;
mod quantize;
use quantize::QuantizeType;
use backend::{
BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
LlamacppSplitMode,
};
use clap::Parser;
use hf_hub::api::tokio::ApiBuilder;
use hf_hub::{Repo, RepoType};
use std::path::Path;
use text_generation_router::{logging, server, usage_stats};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tokenizers::Tokenizer;
use tokio::process::Command;
use tokio::sync::oneshot::error::RecvError;
use tracing::{error, warn};
@@ -25,7 +33,7 @@ struct Args {
/// Path to the GGUF model file for inference.
#[clap(long, env)]
model_gguf: String, // TODO Option() with hf->gguf & quantize
model_gguf: Option<String>,
/// Number of threads to use for generation.
#[clap(long, env)]
@@ -53,7 +61,7 @@ struct Args {
/// Use memory mapping for the model.
#[clap(long, env)]
use_mmap: bool,
disable_mmap: bool,
/// Use memory locking to prevent swapping.
#[clap(long, env)]
@@ -61,11 +69,11 @@ struct Args {
/// Enable offloading of KQV operations to the GPU.
#[clap(long, env)]
offload_kqv: bool,
disable_offload_kqv: bool,
/// Enable flash attention for faster inference. (EXPERIMENTAL)
#[clap(long, env)]
flash_attention: bool,
disable_flash_attention: bool,
/// Data type used for K cache.
#[clap(default_value = "f16", value_enum, long, env)]
@@ -194,35 +202,80 @@ async fn main() -> Result<(), RouterError> {
));
}
// TODO: check if we use the same cache of Server
// check if llamacpp is faster
let tokenizer = {
let token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
let params = FromPretrainedParameters {
revision: args.revision.clone(),
token,
..Default::default()
let api_builder = || {
let mut builder = ApiBuilder::new().with_progress(true);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
if let Ok(token) = std::env::var("HF_TOKEN") {
builder = builder.with_token(token.into());
}
if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
builder = builder.with_user_agent("origin", origin.as_str());
}
builder
};
Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
let api_repo = api_builder().build()?.repo(Repo::with_revision(
args.model_id.clone(),
RepoType::Model,
args.revision.clone(),
));
let tokenizer_path = api_repo.get("tokenizer.json").await?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
let model_gguf = if let Some(model_gguf) = args.model_gguf {
model_gguf
} else {
let model_gguf = format!("models/{}/model.gguf", args.model_id);
let model_gguf_path = Path::new(&model_gguf);
if !model_gguf_path.exists() {
let tmp_gguf = "models/tmp.gguf";
if let Some(parent) = Path::new(model_gguf_path).parent() {
std::fs::create_dir_all(parent)?;
}
let cache_path = tokenizer_path.parent().unwrap();
for sibling in api_repo.info().await?.siblings {
let _ = api_repo.get(&sibling.rfilename).await?;
}
let status = Command::new("convert_hf_to_gguf.py")
.arg("--outfile")
.arg(tmp_gguf)
.arg(cache_path)
.spawn()?
.wait()
.await?;
if !status.success() {
let exit_code = status.code().unwrap_or(-1);
error!("Failed to generate GGUF, exit code: {}", exit_code);
return Err(RouterError::CommandError(exit_code));
}
quantize::model(tmp_gguf, &model_gguf, QuantizeType::MostlyQ4_0, n_threads)
.map_err(RouterError::QuantizeError)?;
}
model_gguf
};
let (backend, ok, shutdown) = LlamacppBackend::new(
LlamacppConfig {
model_gguf: args.model_gguf,
model_gguf,
n_threads,
n_threads_batch,
n_gpu_layers: args.n_gpu_layers,
split_mode: args.split_mode,
defrag_threshold: args.defrag_threshold,
numa: args.numa,
use_mmap: args.use_mmap,
use_mmap: !args.disable_mmap,
use_mlock: args.use_mlock,
flash_attention: args.flash_attention,
flash_attention: !args.disable_flash_attention,
type_k: args.type_k,
type_v: args.type_v,
offload_kqv: args.offload_kqv,
offload_kqv: !args.disable_offload_kqv,
max_batch_total_tokens,
max_physical_batch_total_tokens,
max_batch_size,
@@ -281,4 +334,14 @@ enum RouterError {
WebServer(#[from] server::WebServerError),
#[error("Recv error: {0}")]
RecvError(#[from] RecvError),
#[error("Io error: {0}")]
IoError(#[from] std::io::Error),
#[error("Var error: {0}")]
VarError(#[from] std::env::VarError),
#[error("Quantize error: {0}")]
QuantizeError(String),
#[error("Command error: {0}")]
CommandError(i32),
#[error("HF hub error: {0}")]
HubError(#[from] hf_hub::api::tokio::ApiError),
}
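
For reference, a minimal sketch (all values are placeholders) of how the environment variables consulted above can be supplied when launching the router binary directly:

```bash
# Placeholder values; only variables the router actually reads are shown.
export HUGGINGFACE_HUB_CACHE=/data/hub           # optional: custom Hub cache directory
export HF_TOKEN=hf_xxxxxxxxxxxxxxxx              # optional: token for gated/private repos
export HF_HUB_USER_AGENT_ORIGIN=my-platform/1.0  # optional: forwarded as the "origin" user-agent field

text-generation-router-llamacpp --model-id "Qwen/Qwen2.5-3B-Instruct"
```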

backends/llamacpp/src/quantize.rs

@@ -0,0 +1,35 @@
use crate::llamacpp;
use std::ffi::CString;
#[repr(u32)]
#[derive(Debug, Clone, Copy)]
pub enum QuantizeType {
MostlyQ4_0 = 2,
}
pub fn model(
input_path: &str,
output_path: &str,
ftype: QuantizeType,
n_threads: usize,
) -> Result<(), String> {
let c_input_path =
CString::new(input_path).map_err(|e| format!("Failed to convert input path: {}", e))?;
let c_output_path =
CString::new(output_path).map_err(|e| format!("Failed to convert output path: {}", e))?;
let result = unsafe {
let mut params = llamacpp::model_quantize_default_params();
params.nthread = n_threads as _;
params.ftype = ftype as _;
params.quantize_output_tensor = true;
llamacpp::model_quantize(c_input_path.as_ptr(), c_output_path.as_ptr(), &params)
};
if result == 0 {
Ok(())
} else {
Err(format!("Quantization failed, error code: {}", result))
}
}

docs/source/backends/llamacpp.md

@@ -25,9 +25,12 @@ You will find the best models on [Hugging Face][GGUF].
## Build Docker image
For optimal performance, the Docker image is compiled with native CPU
instructions, thus it's highly recommended to execute the container on
the host used during the build process. Efforts are ongoing to enhance
portability while maintaining high computational efficiency.
instructions by default. As a result, it is strongly recommended to run
the container on the same host architecture used during the build
process. Efforts are ongoing to improve portability across different
systems while preserving high computational efficiency.
To build the Docker image, use the following command:
```bash
docker build \
@@ -38,20 +41,24 @@ docker build \
### Build parameters
| Parameter | Description |
| ------------------------------------ | --------------------------------- |
| `--build-arg llamacpp_version=bXXXX` | Specific version of llama.cpp |
| `--build-arg llamacpp_cuda=ON` | Enables CUDA acceleration |
| `--build-arg cuda_arch=ARCH` | Defines target CUDA architecture |
| Parameter (with --build-arg) | Description |
| ----------------------------------------- | -------------------------------- |
| `llamacpp_version=bXXXX` | Specific version of llama.cpp |
| `llamacpp_cuda=ON` | Enables CUDA acceleration |
| `llamacpp_native=OFF` | Disable automatic CPU detection |
| `llamacpp_cpu_arm_arch=ARCH[+FEATURE]...` | Specific ARM CPU and features |
| `cuda_arch=ARCH` | Defines target CUDA architecture |
## Model preparation
Retrieve a GGUF model and store it in a specific directory, for example:
For example, to target Graviton4 when building on another ARM
architecture:
```bash
mkdir -p ~/models
cd ~/models
curl -LOJ "https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwen2.5-3b-instruct-q4_0.gguf?download=true"
docker build \
-t tgi-llamacpp \
--build-arg llamacpp_native=OFF \
--build-arg llamacpp_cpu_arm_arch=armv9-a+i8mm \
https://github.com/huggingface/text-generation-inference.git \
-f Dockerfile_llamacpp
```
## Run Docker image
@@ -62,10 +69,9 @@ curl -LOJ "https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwe
docker run \
-p 3000:3000 \
-e "HF_TOKEN=$HF_TOKEN" \
-v "$HOME/models:/models" \
-v "$HOME/models:/app/models" \
tgi-llamacpp \
--model-id "Qwen/Qwen2.5-3B-Instruct" \
--model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
--model-id "Qwen/Qwen2.5-3B-Instruct"
```
### GPU-Accelerated inference
@@ -75,13 +81,31 @@ docker run \
--gpus all \
-p 3000:3000 \
-e "HF_TOKEN=$HF_TOKEN" \
-v "$HOME/models:/models" \
-v "$HOME/models:/app/models" \
tgi-llamacpp \
--n-gpu-layers 99
--model-id "Qwen/Qwen2.5-3B-Instruct" \
--model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
--model-id "Qwen/Qwen2.5-3B-Instruct"
```
## Using a custom GGUF
GGUF files are optional as they will be automatically generated at
startup if not already present in the `models` directory. However, if
the default GGUF generation is not suitable for your use case, you can
provide your own GGUF file with `--model-gguf`, for example:
```bash
docker run \
-p 3000:3000 \
-e "HF_TOKEN=$HF_TOKEN" \
-v "$HOME/models:/app/models" \
tgi-llamacpp \
--model-id "Qwen/Qwen2.5-3B-Instruct" \
--model-gguf "models/qwen2.5-3b-instruct-q4_0.gguf"
```
Note that `--model-id` is still required.
## Advanced parameters
A full listing of configurable parameters is available in the `--help`:
@@ -101,10 +125,10 @@ The table below summarizes key options:
| `--split-mode` | Split the model across multiple GPUs |
| `--defrag-threshold` | Defragment the KV cache if holes/size > threshold |
| `--numa` | Enable NUMA optimizations |
| `--use-mmap` | Use memory mapping for the model |
| `--disable-mmap` | Disable memory mapping for the model |
| `--use-mlock` | Use memory locking to prevent swapping |
| `--offload-kqv` | Enable offloading of KQV operations to the GPU |
| `--flash-attention` | Enable flash attention for faster inference |
| `--disable-offload-kqv` | Disable offloading of KQV operations to the GPU |
| `--disable-flash-attention` | Disable flash attention |
| `--type-k` | Data type used for K cache |
| `--type-v` | Data type used for V cache |
| `--validation-workers` | Number of tokenizer workers used for payload validation and truncation |
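
Because mmap, KQV offloading, and flash attention are now enabled by default, the renamed flags act as opt-outs. A hedged example that turns off flash attention while keeping the other defaults (image and model reused from the examples above):

```bash
docker run \
    -p 3000:3000 \
    -e "HF_TOKEN=$HF_TOKEN" \
    -v "$HOME/models:/app/models" \
    tgi-llamacpp \
    --model-id "Qwen/Qwen2.5-3B-Instruct" \
    --disable-flash-attention
```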