mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00

Update the llamacpp backend (#3022)

* Build faster
* Make --model-gguf optional
* Bump llama.cpp
* Enable mmap, offload_kqv & flash_attention by default
* Update doc
* Better error message
* Update doc
* Update installed packages
* Save gguf in models/MODEL_ID/model.gguf
* Fix build with Mach-O
* Quantize without llama-quantize
* Bump llama.cpp and switch to ggml-org
* Remove make-gguf.sh
* Update Cargo.lock
* Support HF_HUB_USER_AGENT_ORIGIN
* Bump llama.cpp
* Add --build-arg llamacpp_native & llamacpp_cpu_arm_arch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

parent dc5f05f8e6
commit 094975c3a8
Cargo.lock (generated)
@@ -4754,6 +4754,7 @@ dependencies = [
  "async-trait",
  "bindgen 0.71.1",
  "clap 4.5.30",
+ "hf-hub 0.3.2",
  "num_cpus",
  "pkg-config",
  "text-generation-router",
Dockerfile_llamacpp

@@ -1,13 +1,15 @@
 FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps

-ARG llamacpp_version=b4651
+ARG llamacpp_version=b4827
 ARG llamacpp_cuda=OFF
+ARG llamacpp_native=ON
+ARG llamacpp_cpu_arm_arch=native
 ARG cuda_arch=75-real;80-real;86-real;89-real;90-real

 WORKDIR /opt/src

 ENV DEBIAN_FRONTEND=noninteractive
-RUN apt update && apt install -y \
+RUN apt update && apt upgrade -y && apt install -y \
     clang \
     cmake \
     curl \
@@ -17,9 +19,10 @@ RUN apt update && apt install -y \
     pkg-config \
     tar

-ADD https://github.com/ggerganov/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
+ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
-RUN tar -xzf ${llamacpp_version}.tar.gz \
-    && cd llama.cpp-${llamacpp_version} \
+RUN mkdir -p llama.cpp \
+    && tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
+    && cd llama.cpp \
     && cmake -B build \
         -DCMAKE_INSTALL_PREFIX=/usr \
         -DCMAKE_INSTALL_LIBDIR=/usr/lib \
@@ -27,6 +30,8 @@ RUN tar -xzf ${llamacpp_version}.tar.gz \
         -DCMAKE_CXX_COMPILER=clang++ \
         -DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
         -DGGML_CUDA=${llamacpp_cuda} \
+        -DGGML_NATIVE=${llamacpp_native} \
+        -DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
         -DLLAMA_BUILD_COMMON=OFF \
         -DLLAMA_BUILD_TESTS=OFF \
         -DLLAMA_BUILD_EXAMPLES=OFF \
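As an illustration only (not part of this commit), the build arguments above combine in the obvious way for a CUDA build; a sketch, assuming the tag `tgi-llamacpp-cuda` and a single target architecture:

```bash
# Hypothetical CUDA-enabled build; the tag and cuda_arch value are illustrative.
docker build \
    -t tgi-llamacpp-cuda \
    --build-arg llamacpp_cuda=ON \
    --build-arg cuda_arch="89-real" \
    https://github.com/huggingface/text-generation-inference.git \
    -f Dockerfile_llamacpp
```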
@@ -48,16 +53,18 @@ FROM deps AS builder
 COPY --from=planner /app/recipe.json recipe.json
 RUN cargo chef cook \
     --recipe-path recipe.json \
-    --profile release-opt \
+    --profile release \
     --package text-generation-router-llamacpp
 COPY . .
 RUN cargo build \
-    --profile release-opt \
+    --profile release \
     --package text-generation-router-llamacpp --frozen

 FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
+WORKDIR /app

-RUN apt update && apt install -y \
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt update && apt upgrade -y && apt install -y \
     python3-venv \
     python3-pip

@@ -65,11 +72,16 @@ RUN python3 -m venv /venv
 ENV PATH="/venv/bin:$PATH"

 COPY backends/llamacpp/requirements.txt requirements.txt
-RUN pip3 install --no-cache-dir -r requirements.txt
+COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
+COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
+
+RUN pip3 install --no-cache-dir \
+    -r requirements.txt \
+    -e gguf-py

 COPY --from=builder /usr/lib/libllama.so /usr/lib/
 COPY --from=builder /usr/lib/libggml*.so /usr/lib/
-COPY --from=builder /app/target/release-opt/text-generation-router-llamacpp /usr/bin/
+COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/

 ENV HF_HUB_ENABLE_HF_TRANSFER=1
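Since the runtime stage now copies `convert_hf_to_gguf.py` into `/bin` and installs `gguf-py` into the virtualenv, the conversion tooling should be usable directly from the image. A quick sanity check, assuming the image was tagged `tgi-llamacpp` (the entrypoint override is only for inspection):

```bash
# Assumed check: invoke the conversion script shipped in the runtime image.
docker run --rm --entrypoint convert_hf_to_gguf.py tgi-llamacpp --help
```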
backends/llamacpp/Cargo.toml

@@ -12,10 +12,11 @@ pkg-config = "0.3.31"
 [dependencies]
 async-trait = "0.1.85"
 clap = "4.5.27"
+hf-hub.workspace = true
 num_cpus = "1.16.0"
 text-generation-router = { path = "../../router" }
 thiserror = "2.0.11"
 tokenizers.workspace = true
-tokio = "1.43.0"
+tokio = { version = "1.43.0", features = ["process"] }
 tokio-stream = "0.1.17"
 tracing = "0.1.41"
backends/llamacpp/build.rs

@@ -25,8 +25,9 @@ fn main() {
     for path in &llama.link_paths {
         println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
     }
+    if cfg!(target_os = "linux") {
         println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
+    }
     let bindings = bindgen::Builder::default()
         .clang_args(
             llama
backends/llamacpp/requirements.txt

@@ -1,3 +1,4 @@
 transformers==4.49
 huggingface-hub==0.28.1
 hf-transfer==0.1.9
+torch==2.6.0
backends/llamacpp/src/backend.rs

@@ -1,10 +1,5 @@
-mod llamacpp {
-    #![allow(non_upper_case_globals)]
-    #![allow(non_camel_case_types)]
-    #![allow(non_snake_case)]
-    #![allow(dead_code)]
-    include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
-}
+use crate::llamacpp;
 use async_trait::async_trait;
 use std::ffi::CString;
 use std::mem::replace;
backends/llamacpp/src/llamacpp.rs (new file, 5 lines)

@@ -0,0 +1,5 @@
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(dead_code)]
+include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
backends/llamacpp/src/main.rs

@@ -1,13 +1,21 @@
 mod backend;
+mod llamacpp;
+mod quantize;
+
+use quantize::QuantizeType;
+
 use backend::{
     BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
     LlamacppSplitMode,
 };
 use clap::Parser;
+use hf_hub::api::tokio::ApiBuilder;
+use hf_hub::{Repo, RepoType};
+use std::path::Path;
 use text_generation_router::{logging, server, usage_stats};
 use thiserror::Error;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use tokenizers::Tokenizer;
+use tokio::process::Command;
 use tokio::sync::oneshot::error::RecvError;
 use tracing::{error, warn};
@@ -25,7 +33,7 @@ struct Args {

     /// Path to the GGUF model file for inference.
     #[clap(long, env)]
-    model_gguf: String, // TODO Option() with hf->gguf & quantize
+    model_gguf: Option<String>,

     /// Number of threads to use for generation.
     #[clap(long, env)]
@@ -53,7 +61,7 @@ struct Args {

     /// Use memory mapping for the model.
     #[clap(long, env)]
-    use_mmap: bool,
+    disable_mmap: bool,

     /// Use memory locking to prevent swapping.
     #[clap(long, env)]
@@ -61,11 +69,11 @@ struct Args {

     /// Enable offloading of KQV operations to the GPU.
     #[clap(long, env)]
-    offload_kqv: bool,
+    disable_offload_kqv: bool,

     /// Enable flash attention for faster inference. (EXPERIMENTAL)
     #[clap(long, env)]
-    flash_attention: bool,
+    disable_flash_attention: bool,

     /// Data type used for K cache.
     #[clap(default_value = "f16", value_enum, long, env)]
@@ -194,35 +202,80 @@ async fn main() -> Result<(), RouterError> {
         ));
     }

-    // TODO: check if we use the same cache of Server
-    // check if llamacpp is faster
-    let tokenizer = {
-        let token = std::env::var("HF_TOKEN")
-            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
-            .ok();
-        let params = FromPretrainedParameters {
-            revision: args.revision.clone(),
-            token,
-            ..Default::default()
-        };
-        Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
-    };
+    let api_builder = || {
+        let mut builder = ApiBuilder::new().with_progress(true);
+
+        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
+            builder = builder.with_cache_dir(cache_dir.into());
+        }
+        if let Ok(token) = std::env::var("HF_TOKEN") {
+            builder = builder.with_token(token.into());
+        }
+        if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
+            builder = builder.with_user_agent("origin", origin.as_str());
+        }
+        builder
+    };
+    let api_repo = api_builder().build()?.repo(Repo::with_revision(
+        args.model_id.clone(),
+        RepoType::Model,
+        args.revision.clone(),
+    ));
+
+    let tokenizer_path = api_repo.get("tokenizer.json").await?;
+    let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
+
+    let model_gguf = if let Some(model_gguf) = args.model_gguf {
+        model_gguf
+    } else {
+        let model_gguf = format!("models/{}/model.gguf", args.model_id);
+        let model_gguf_path = Path::new(&model_gguf);
+
+        if !model_gguf_path.exists() {
+            let tmp_gguf = "models/tmp.gguf";
+
+            if let Some(parent) = Path::new(model_gguf_path).parent() {
+                std::fs::create_dir_all(parent)?;
+            }
+            let cache_path = tokenizer_path.parent().unwrap();
+
+            for sibling in api_repo.info().await?.siblings {
+                let _ = api_repo.get(&sibling.rfilename).await?;
+            }
+            let status = Command::new("convert_hf_to_gguf.py")
+                .arg("--outfile")
+                .arg(tmp_gguf)
+                .arg(cache_path)
+                .spawn()?
+                .wait()
+                .await?;
+
+            if !status.success() {
+                let exit_code = status.code().unwrap_or(-1);
+                error!("Failed to generate GGUF, exit code: {}", exit_code);
+                return Err(RouterError::CommandError(exit_code));
+            }
+            quantize::model(tmp_gguf, &model_gguf, QuantizeType::MostlyQ4_0, n_threads)
+                .map_err(RouterError::QuantizeError)?;
+        }
+        model_gguf
+    };

     let (backend, ok, shutdown) = LlamacppBackend::new(
         LlamacppConfig {
-            model_gguf: args.model_gguf,
+            model_gguf,
             n_threads,
             n_threads_batch,
             n_gpu_layers: args.n_gpu_layers,
             split_mode: args.split_mode,
             defrag_threshold: args.defrag_threshold,
             numa: args.numa,
-            use_mmap: args.use_mmap,
+            use_mmap: !args.disable_mmap,
             use_mlock: args.use_mlock,
-            flash_attention: args.flash_attention,
+            flash_attention: !args.disable_flash_attention,
             type_k: args.type_k,
             type_v: args.type_v,
-            offload_kqv: args.offload_kqv,
+            offload_kqv: !args.disable_offload_kqv,
             max_batch_total_tokens,
             max_physical_batch_total_tokens,
             max_batch_size,
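The startup path above (fetch the repo, convert it with `convert_hf_to_gguf.py`, then quantize to Q4_0) is roughly what one would do by hand with llama.cpp tooling. A sketch of the manual equivalent, with illustrative paths and `llama-quantize` standing in for the router's in-process quantization:

```bash
# Approximate manual equivalent of the automatic GGUF generation above.
huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir ./hf-model
convert_hf_to_gguf.py --outfile ./tmp.gguf ./hf-model
# The router quantizes in-process; llama-quantize is the standalone counterpart.
llama-quantize ./tmp.gguf ./models/Qwen/Qwen2.5-3B-Instruct/model.gguf Q4_0 8
```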
@@ -281,4 +334,14 @@ enum RouterError {
     WebServer(#[from] server::WebServerError),
     #[error("Recv error: {0}")]
     RecvError(#[from] RecvError),
+    #[error("Io error: {0}")]
+    IoError(#[from] std::io::Error),
+    #[error("Var error: {0}")]
+    VarError(#[from] std::env::VarError),
+    #[error("Quantize error: {0}")]
+    QuantizeError(String),
+    #[error("Command error: {0}")]
+    CommandError(i32),
+    #[error("HF hub error: {0}")]
+    HubError(#[from] hf_hub::api::tokio::ApiError),
 }
backends/llamacpp/src/quantize.rs (new file, 35 lines)

@@ -0,0 +1,35 @@
+use crate::llamacpp;
+
+use std::ffi::CString;
+
+#[repr(u32)]
+#[derive(Debug, Clone, Copy)]
+pub enum QuantizeType {
+    MostlyQ4_0 = 2,
+}
+
+pub fn model(
+    input_path: &str,
+    output_path: &str,
+    ftype: QuantizeType,
+    n_threads: usize,
+) -> Result<(), String> {
+    let c_input_path =
+        CString::new(input_path).map_err(|e| format!("Failed to convert input path: {}", e))?;
+
+    let c_output_path =
+        CString::new(output_path).map_err(|e| format!("Failed to convert output path: {}", e))?;
+
+    let result = unsafe {
+        let mut params = llamacpp::model_quantize_default_params();
+        params.nthread = n_threads as _;
+        params.ftype = ftype as _;
+        params.quantize_output_tensor = true;
+        llamacpp::model_quantize(c_input_path.as_ptr(), c_output_path.as_ptr(), &params)
+    };
+    if result == 0 {
+        Ok(())
+    } else {
+        Err(format!("Quantization failed, error code: {}", result))
+    }
+}
docs/source/backends/llamacpp.md

@@ -25,9 +25,12 @@ You will find the best models on [Hugging Face][GGUF].
 ## Build Docker image

 For optimal performance, the Docker image is compiled with native CPU
-instructions, thus it's highly recommended to execute the container on
-the host used during the build process. Efforts are ongoing to enhance
-portability while maintaining high computational efficiency.
+instructions by default. As a result, it is strongly recommended to run
+the container on the same host architecture used during the build
+process. Efforts are ongoing to improve portability across different
+systems while preserving high computational efficiency.
+
+To build the Docker image, use the following command:

 ```bash
 docker build \
@@ -38,20 +41,24 @@ docker build \

 ### Build parameters

-| Parameter                            | Description                       |
-| ------------------------------------ | --------------------------------- |
-| `--build-arg llamacpp_version=bXXXX` | Specific version of llama.cpp     |
-| `--build-arg llamacpp_cuda=ON`       | Enables CUDA acceleration         |
-| `--build-arg cuda_arch=ARCH`         | Defines target CUDA architecture  |
+| Parameter (with --build-arg)              | Description                      |
+| ----------------------------------------- | -------------------------------- |
+| `llamacpp_version=bXXXX`                  | Specific version of llama.cpp    |
+| `llamacpp_cuda=ON`                        | Enables CUDA acceleration        |
+| `llamacpp_native=OFF`                     | Disable automatic CPU detection  |
+| `llamacpp_cpu_arm_arch=ARCH[+FEATURE]...` | Specific ARM CPU and features    |
+| `cuda_arch=ARCH`                          | Defines target CUDA architecture |

-## Model preparation
-
-Retrieve a GGUF model and store it in a specific directory, for example:
+For example, to target Graviton4 when building on another ARM
+architecture:

 ```bash
-mkdir -p ~/models
-cd ~/models
-curl -LOJ "https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwen2.5-3b-instruct-q4_0.gguf?download=true"
+docker build \
+    -t tgi-llamacpp \
+    --build-arg llamacpp_native=OFF \
+    --build-arg llamacpp_cpu_arm_arch=armv9-a+i8mm \
+    https://github.com/huggingface/text-generation-inference.git \
+    -f Dockerfile_llamacpp
 ```

 ## Run Docker image
@@ -62,10 +69,9 @@ curl -LOJ "https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwen2.5-3b-instruct-q4_0.gguf?download=true"
 docker run \
     -p 3000:3000 \
     -e "HF_TOKEN=$HF_TOKEN" \
-    -v "$HOME/models:/models" \
+    -v "$HOME/models:/app/models" \
     tgi-llamacpp \
-    --model-id "Qwen/Qwen2.5-3B-Instruct" \
-    --model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
+    --model-id "Qwen/Qwen2.5-3B-Instruct"
 ```

 ### GPU-Accelerated inference
@@ -75,13 +81,31 @@ docker run \
     --gpus all \
     -p 3000:3000 \
     -e "HF_TOKEN=$HF_TOKEN" \
-    -v "$HOME/models:/models" \
+    -v "$HOME/models:/app/models" \
     tgi-llamacpp \
     --n-gpu-layers 99
-    --model-id "Qwen/Qwen2.5-3B-Instruct" \
-    --model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
+    --model-id "Qwen/Qwen2.5-3B-Instruct"
 ```

+## Using a custom GGUF
+
+GGUF files are optional as they will be automatically generated at
+startup if not already present in the `models` directory. However, if
+the default GGUF generation is not suitable for your use case, you can
+provide your own GGUF file with `--model-gguf`, for example:
+
+```bash
+docker run \
+    -p 3000:3000 \
+    -e "HF_TOKEN=$HF_TOKEN" \
+    -v "$HOME/models:/app/models" \
+    tgi-llamacpp \
+    --model-id "Qwen/Qwen2.5-3B-Instruct" \
+    --model-gguf "models/qwen2.5-3b-instruct-q4_0.gguf"
+```
+
+Note that `--model-id` is still required.
+
 ## Advanced parameters

 A full listing of configurable parameters is available in the `--help`:
@@ -101,10 +125,10 @@ The table below summarizes key options:
 | `--split-mode`              | Split the model across multiple GPUs                                    |
 | `--defrag-threshold`        | Defragment the KV cache if holes/size > threshold                       |
 | `--numa`                    | Enable NUMA optimizations                                               |
-| `--use-mmap`                | Use memory mapping for the model                                        |
+| `--disable-mmap`            | Disable memory mapping for the model                                    |
 | `--use-mlock`               | Use memory locking to prevent swapping                                  |
-| `--offload-kqv`             | Enable offloading of KQV operations to the GPU                          |
+| `--disable-offload-kqv`     | Disable offloading of KQV operations to the GPU                         |
-| `--flash-attention`         | Enable flash attention for faster inference                             |
+| `--disable-flash-attention` | Disable flash attention                                                 |
 | `--type-k`                  | Data type used for K cache                                              |
 | `--type-v`                  | Data type used for V cache                                              |
 | `--validation-workers`      | Number of tokenizer workers used for payload validation and truncation  |
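Because mmap, KQV offloading and flash attention are now enabled by default, the new flags opt out rather than opt in. A sketch of a run that disables memory mapping and locks the model in RAM instead; the flag combination is purely illustrative:

```bash
# Illustrative use of the inverted flags: mmap off, mlock on.
docker run \
    -p 3000:3000 \
    -v "$HOME/models:/app/models" \
    tgi-llamacpp \
    --model-id "Qwen/Qwen2.5-3B-Instruct" \
    --disable-mmap \
    --use-mlock
```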