mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Merge branch 'main' into sliding_window
This commit is contained in:
commit
1d384f8c98
@ -4,3 +4,4 @@ server/transformers
|
||||
server/flash-attention
|
||||
cmake-build-debug/
|
||||
cmake-build-release/
|
||||
Dockerfile*
|
||||
|
3
.github/workflows/tests.yaml
vendored
3
.github/workflows/tests.yaml
vendored
@ -58,3 +58,6 @@ jobs:
|
||||
- name: Run Rust tests
|
||||
run: |
|
||||
cargo test
|
||||
- name: Run Rust tests with google feature
|
||||
run: |
|
||||
cargo test --features google
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,9 +3,8 @@ target
|
||||
router/tokenizer.json
|
||||
*__pycache__*
|
||||
|
||||
backends/v2/src/client/pb
|
||||
backends/v3/src/client/pb
|
||||
backends/client/src/v2/pb
|
||||
backends/client/src/v3/pb
|
||||
|
||||
# ROCm auto-generated files
|
||||
*.hip
|
||||
|
386
Cargo.lock
generated
386
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,26 +1,26 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"benchmark",
|
||||
"backends/v2",
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
"backends/trtllm",
|
||||
"backends/client",
|
||||
"launcher",
|
||||
"router"
|
||||
]
|
||||
default-members = [
|
||||
"benchmark",
|
||||
"backends/v2",
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
# "backends/trtllm",
|
||||
"backends/client",
|
||||
"launcher",
|
||||
"router"
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
minijinja = { version = "2.2.0", features = ["json"] }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
||||
|
||||
[profile.release]
|
||||
incremental = true
|
||||
|
@ -40,7 +40,6 @@ COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Python builder
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
|
24
Dockerfile.nix
Normal file
24
Dockerfile.nix
Normal file
@ -0,0 +1,24 @@
|
||||
# Build the image and get out the docker file:
|
||||
#
|
||||
# docker build -t tgi-nix-builder -f Dockerfile.nix
|
||||
# docker run --log-driver=none tgi-nix-builder | docker load
|
||||
|
||||
FROM nixos/nix:2.18.8 AS builder
|
||||
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
|
||||
RUN nix profile install nixpkgs#cachix
|
||||
RUN cachix use text-generation-inference
|
||||
WORKDIR /root
|
||||
ADD . .
|
||||
RUN nix build .
|
||||
RUN mkdir /tmp/nix-store-closure
|
||||
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure
|
||||
|
||||
FROM ubuntu:24.04
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy /nix/store
|
||||
COPY --from=builder /tmp/nix-store-closure /nix/store
|
||||
COPY --from=builder /root/result /app
|
||||
RUN ldconfig
|
||||
CMD ["ldconfig", "/app/bin/text-generation-launcher"]
|
175
Dockerfile_amd
175
Dockerfile_amd
@ -41,7 +41,7 @@ COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Text Generation Inference base image for RoCm
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
|
||||
FROM rocm/dev-ubuntu-22.04:6.2 AS base
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
@ -50,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
curl \
|
||||
git \
|
||||
make \
|
||||
libmsgpack-dev \
|
||||
libssl-dev \
|
||||
llvm-dev \
|
||||
g++ \
|
||||
# Needed to build VLLM & flash.
|
||||
rocthrust-dev \
|
||||
hipsparse-dev \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
hipcub-dev \
|
||||
rocblas-dev \
|
||||
hiprand-dev \
|
||||
hipfft-dev \
|
||||
rocrand-dev \
|
||||
miopen-hip-dev \
|
||||
hipfft-dev \
|
||||
hipcub-dev \
|
||||
hipsolver-dev \
|
||||
rccl-dev \
|
||||
cmake \
|
||||
python3.11-dev && \
|
||||
python3.11-venv && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTORCH_VERSION='2.3.0'
|
||||
ARG ROCM_VERSION='6.0.2'
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
ENV PATH=/opt/conda/bin:$PATH
|
||||
|
||||
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||
|
||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
@ -100,41 +101,132 @@ RUN case ${TARGETPLATFORM} in \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
# Install flash-attention, torch dependencies
|
||||
RUN pip install numpy einops ninja --no-cache-dir
|
||||
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip uninstall -y triton && \
|
||||
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
|
||||
cd triton/python && \
|
||||
pip install .
|
||||
RUN conda install mkl=2021
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
|
||||
|
||||
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
|
||||
|
||||
ARG _GLIBCXX_USE_CXX11_ABI="1"
|
||||
ARG CMAKE_PREFIX_PATH="/opt/conda"
|
||||
ARG COMMON_WORKDIR=/
|
||||
WORKDIR ${COMMON_WORKDIR}
|
||||
|
||||
|
||||
# Install HIPBLASLt
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH="e6da924"
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt.git \
|
||||
&& cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
|
||||
FROM scratch AS export_hipblaslt
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
|
||||
|
||||
# RCCL build stages
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH="rocm-6.2.0"
|
||||
RUN git clone https://github.com/ROCm/rccl \
|
||||
&& cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
FROM scratch AS export_rccl
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
|
||||
|
||||
# Triton build stages
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH="e192dba"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
|
||||
&& cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_triton
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
|
||||
|
||||
# # AMD-SMI build stages
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
FROM scratch AS export_amdsmi
|
||||
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
|
||||
|
||||
|
||||
FROM base as build_pytorch
|
||||
|
||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
|
||||
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||
ARG BUILD_CAFFE2="0" \
|
||||
BUILD_CAFFE2_OPS="0" \
|
||||
USE_CUDA="0" \
|
||||
USE_ROCM="1" \
|
||||
BUILD_TEST="0" \
|
||||
USE_FBGEMM="0" \
|
||||
USE_NNPACK="0" \
|
||||
USE_QNNPACK="0" \
|
||||
USE_XNNPACK="0" \
|
||||
USE_FLASH_ATTENTION="1" \
|
||||
USE_MEM_EFF_ATTENTION="0"
|
||||
|
||||
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||
# A commit to fix the output scaling factor issue in _scaled_mm
|
||||
# Not yet in 2.5.0-rc1
|
||||
ARG PYTORCH_BRANCH="cedc116"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
|
||||
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
RUN git clone ${PYTORCH_REPO} pytorch \
|
||||
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt --no-cache-dir \
|
||||
&& python tools/amd_build/build_amd.py \
|
||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch as export_pytorch
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
FROM base AS install_deps
|
||||
|
||||
FROM base AS kernel-builder
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Install hipblaslt
|
||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
# RCCL needs to be installed twice
|
||||
&& dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
|
||||
if ls /install/*.whl; then \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y triton \
|
||||
&& pip install /install/*.whl; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y amdsmi \
|
||||
&& pip install /install/*.whl;
|
||||
|
||||
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
|
||||
if ls /install/*.whl; then \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y torch torchvision \
|
||||
&& pip install /install/*.whl; \
|
||||
fi
|
||||
|
||||
FROM install_deps AS kernel-builder
|
||||
|
||||
# # Build vllm kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
@ -174,7 +266,7 @@ COPY server/exllamav2_kernels/ .
|
||||
|
||||
RUN python setup.py build
|
||||
|
||||
FROM base AS base-copy
|
||||
FROM install_deps AS base-copy
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HF_HOME=/data \
|
||||
@ -224,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
# Final image
|
||||
FROM base-copy
|
||||
|
||||
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
||||
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
||||
ENV VLLM_MOE_PADDING=0
|
||||
ENV ATTENTION=paged
|
||||
ENV USE_PREFIX_CACHING=0
|
||||
ENV ROCM_USE_SKINNY_GEMM=1
|
||||
|
||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
|
@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
|
||||
volume=$PWD/data
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.0 --model-id $model
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
|
||||
```
|
||||
|
||||
And then you can make requests like
|
||||
|
75
backends/v2/Cargo.toml
Normal file
75
backends/v2/Cargo.toml
Normal file
@ -0,0 +1,75 @@
|
||||
[package]
|
||||
name = "text-generation-router-v2"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router-v2"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { workspace = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
19
backends/v2/build.rs
Normal file
19
backends/v2/build.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
506
backends/v2/src/backend.rs
Normal file
506
backends/v2/src/backend.rs
Normal file
@ -0,0 +1,506 @@
|
||||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Client clone, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).inspect_err(|_err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v2_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
257
backends/v2/src/client/grpc_client.rs
Normal file
257
backends/v2/src/client/grpc_client.rs
Normal file
@ -0,0 +1,257 @@
|
||||
/// Single shard Client
|
||||
use crate::client::pb;
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v2::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str(&format!(
|
||||
"",
|
||||
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
68
backends/v2/src/client/mod.rs
Normal file
68
backends/v2/src/client/mod.rs
Normal file
@ -0,0 +1,68 @@
|
||||
//! Text Generation gRPC client library
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod grpc_client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use grpc_client::Client;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
|
||||
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
let err = Self::Generation(err.message().to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl From<transport::Error> for ClientError {
|
||||
fn from(err: transport::Error) -> Self {
|
||||
let err = Self::Connection(err.to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
252
backends/v2/src/client/sharded_client.rs
Normal file
252
backends/v2/src/client/sharded_client.rs
Normal file
@ -0,0 +1,252 @@
|
||||
/// Multi shard Client
|
||||
use crate::client::{ClientError, Result};
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::InfoResponse;
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await?;
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let master_client = Client::connect_uds(path).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
let results = join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||
Ok(results.into_iter().flatten().min())
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
141
backends/v2/src/lib.rs
Normal file
141
backends/v2/src/lib.rs
Normal file
@ -0,0 +1,141 @@
|
||||
mod backend;
|
||||
mod client;
|
||||
mod queue;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV2;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn connect_backend(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
master_shard_uds_path: String,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens
|
||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V2Error::Connection)?;
|
||||
|
||||
// server is running on v2
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V2Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V2Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV2::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V2Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
}
|
212
backends/v2/src/main.rs
Normal file
212
backends/v2/src/main.rs
Normal file
@ -0,0 +1,212 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use text_generation_router::{server, usage_stats};
|
||||
use text_generation_router_v2::{connect_backend, V2Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
PrintSchema,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
command,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
api_key,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
use utoipa::OpenApi;
|
||||
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
};
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
if max_batch_size == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
api_key,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V2Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
@ -1,14 +1,14 @@
|
||||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v2::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use text_generation_client::ChunksToString;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
@ -218,7 +218,7 @@ impl State {
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
@ -404,6 +404,7 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> (
|
||||
@ -415,7 +416,9 @@ mod tests {
|
||||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
parameters: ValidParameters {
|
@ -100,6 +100,7 @@ pub async fn connect_backend(
|
||||
.map_err(V3Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
|
@ -16,7 +16,6 @@ path = "src/main.rs"
|
||||
[dependencies]
|
||||
average = "0.14"
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
crossterm = "0.28.1"
|
||||
float-ord = "0.3.2"
|
||||
serde = {version = "1.0.188", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
@ -25,7 +24,7 @@ text-generation-client = { path = "../backends/client" }
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||
ratatui = { version = "0.28.1", default-features = false, features = ["crossterm"] }
|
||||
ratatui = "0.28.1"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
hf-hub = { workspace = true }
|
||||
|
@ -7,7 +7,7 @@
|
||||
</div>
|
||||
|
||||
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
|
||||
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
|
||||
and powered by [Ratatui](https://github.com/ratatui/ratatui).
|
||||
|
||||
## Install
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
||||
use crate::generation::{Decode, Message, Prefill};
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
||||
use crossterm::event;
|
||||
use ratatui::crossterm::event;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
|
||||
|
@ -6,8 +6,8 @@ mod utils;
|
||||
|
||||
use crate::app::App;
|
||||
use crate::event::Event;
|
||||
use crossterm::ExecutableCommand;
|
||||
use ratatui::backend::CrosstermBackend;
|
||||
use ratatui::crossterm::ExecutableCommand;
|
||||
use ratatui::Terminal;
|
||||
use std::io;
|
||||
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
||||
@ -50,9 +50,9 @@ pub async fn run(
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
crossterm::terminal::enable_raw_mode()?;
|
||||
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
||||
io::stdout().execute(crossterm::cursor::Hide)?;
|
||||
ratatui::crossterm::terminal::enable_raw_mode()?;
|
||||
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
|
||||
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
|
||||
|
||||
// Initialize terminal
|
||||
let mut terminal = {
|
||||
@ -128,9 +128,9 @@ pub async fn run(
|
||||
let _ = shutdown_guard_receiver.recv().await;
|
||||
|
||||
// Revert terminal to original view
|
||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
||||
crossterm::terminal::disable_raw_mode()?;
|
||||
io::stdout().execute(crossterm::cursor::Show)?;
|
||||
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
|
||||
ratatui::crossterm::terminal::disable_raw_mode()?;
|
||||
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
|
||||
|
||||
let parameters_table = table::parameters_table(
|
||||
tokenizer_name,
|
||||
|
@ -28,11 +28,17 @@ class ToolCall(BaseModel):
|
||||
function: dict
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
type: str
|
||||
text: Optional[str] = None
|
||||
image_url: Any = None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
# Role of the message sender
|
||||
role: str
|
||||
# Content of the message
|
||||
content: Optional[str] = None
|
||||
content: Optional[Union[str, List[Chunk]]] = None
|
||||
# Optional name of the message sender
|
||||
name: Optional[str] = None
|
||||
# Tool calls associated with the chat completion
|
||||
|
@ -10,7 +10,7 @@
|
||||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "2.3.1-dev0"
|
||||
"version": "2.3.2-dev0"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
|
@ -10,7 +10,7 @@ This diagram shows well there are these separate components:
|
||||
|
||||
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
|
||||
- **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
|
||||
- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
|
||||
- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
|
||||
|
||||
The router and the model server can be two different machines, they do not need to be deployed together.
|
||||
|
||||
|
@ -36,7 +36,13 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||
```
|
||||
|
||||
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
|
||||
To specify model revision, use `adapter_id@revision`, as follows:
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
|
||||
```
|
||||
|
||||
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
||||
@ -72,6 +78,22 @@ curl 127.0.0.1:3000/generate \
|
||||
}'
|
||||
```
|
||||
|
||||
If you are using a lora adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example payload:
|
||||
|
||||
```json
|
||||
curl 127.0.0.1:3000/generate \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"inputs": "Hello who are you?",
|
||||
"parameters": {
|
||||
"max_new_tokens": 40,
|
||||
"adapter_id": "myadapter"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon.
|
||||
|
||||
An updated tutorial with detailed examples will be published soon. Stay tuned!
|
||||
|
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.0-rocm \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
|
||||
|
||||
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
|
||||
|
||||
## Custom PagedAttention
|
||||
|
||||
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
|
||||
|
||||
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
|
||||
|
||||
## Unsupported features
|
||||
|
||||
The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
|
||||
|
@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm --privileged --cap-add=sys_nice \
|
||||
--device=/dev/dri \
|
||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-xpu \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-xpu \
|
||||
--model-id $model --cuda-graphs 0
|
||||
```
|
||||
|
||||
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm --privileged --cap-add=sys_nice \
|
||||
--device=/dev/dri \
|
||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-cpu \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-cpu \
|
||||
--model-id $model --cuda-graphs 0
|
||||
```
|
||||
|
||||
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -11,10 +11,19 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:2.3.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you want to serve gated or private models, which provide
|
||||
controlled access to sensitive or proprietary content, refer to
|
||||
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
|
||||
for detailed instructions.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Supported hardware
|
||||
|
||||
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
|
||||
|
@ -89,6 +89,15 @@ Options:
|
||||
[env: DTYPE=]
|
||||
[possible values: float16, bfloat16]
|
||||
|
||||
```
|
||||
## KV_CACHE_DTYPE
|
||||
```shell
|
||||
--kv-cache-dtype <KV_CACHE_DTYPE>
|
||||
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
|
||||
|
||||
[env: KV_CACHE_DTYPE=]
|
||||
[possible values: fp8_e5m2]
|
||||
|
||||
```
|
||||
## TRUST_REMOTE_CODE
|
||||
```shell
|
||||
|
@ -20,6 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
||||
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
|
||||
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
|
||||
- [Phi](https://huggingface.co/microsoft/phi-1_5)
|
||||
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
|
||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
|
||||
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
|
||||
@ -34,6 +35,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
||||
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
|
||||
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
|
||||
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
|
||||
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
|
||||
|
||||
|
||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||
|
32
flake.lock
32
flake.lock
@ -497,11 +497,11 @@
|
||||
"systems": "systems_7"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"lastModified": 1726560853,
|
||||
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@ -718,11 +718,11 @@
|
||||
},
|
||||
"nixpkgs_6": {
|
||||
"locked": {
|
||||
"lastModified": 1724915739,
|
||||
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
|
||||
"lastModified": 1727675176,
|
||||
"narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
|
||||
"rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@ -853,11 +853,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1726626348,
|
||||
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
|
||||
"lastModified": 1727836133,
|
||||
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
|
||||
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@ -978,16 +978,16 @@
|
||||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1726743157,
|
||||
"narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
|
||||
"owner": "danieldk",
|
||||
"repo": "tgi-nix",
|
||||
"rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
|
||||
"lastModified": 1728029332,
|
||||
"narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"rev": "98049f853346ca780b81fee730715c90d33ac2b4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"repo": "tgi-nix",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
|
59
flake.nix
59
flake.nix
@ -5,7 +5,7 @@
|
||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
};
|
||||
nix-filter.url = "github:numtide/nix-filter";
|
||||
tgi-nix.url = "github:danieldk/tgi-nix";
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
@ -37,6 +37,7 @@
|
||||
overlays = [
|
||||
rust-overlay.overlays.default
|
||||
tgi-nix.overlays.default
|
||||
(import nix/overlay.nix)
|
||||
];
|
||||
};
|
||||
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
|
||||
@ -132,53 +133,17 @@
|
||||
pre-commit
|
||||
ruff
|
||||
]);
|
||||
|
||||
};
|
||||
|
||||
impure = mkShell {
|
||||
buildInputs =
|
||||
[
|
||||
openssl.dev
|
||||
pkg-config
|
||||
(rust-bin.stable.latest.default.override {
|
||||
extensions = [
|
||||
"rust-analyzer"
|
||||
"rust-src"
|
||||
];
|
||||
})
|
||||
protobuf
|
||||
]
|
||||
++ (with python3.pkgs; [
|
||||
venvShellHook
|
||||
docker
|
||||
pip
|
||||
ipdb
|
||||
click
|
||||
pyright
|
||||
pytest
|
||||
pytest-asyncio
|
||||
redocly
|
||||
ruff
|
||||
syrupy
|
||||
]);
|
||||
impure = callPackage ./nix/impure-shell.nix { inherit server; };
|
||||
|
||||
inputsFrom = [ server ];
|
||||
|
||||
venvDir = "./.venv";
|
||||
|
||||
postVenvCreation = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
( cd server ; python -m pip install --no-dependencies -e . )
|
||||
( cd clients/python ; python -m pip install --no-dependencies -e . )
|
||||
'';
|
||||
postShellHook = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
export PATH=$PATH:~/.cargo/bin
|
||||
'';
|
||||
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
|
||||
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
|
||||
};
|
||||
};
|
||||
|
||||
packages.default = pkgs.writeShellApplication {
|
||||
packages = rec {
|
||||
default = pkgs.writeShellApplication {
|
||||
name = "text-generation-inference";
|
||||
runtimeInputs = [
|
||||
server
|
||||
@ -188,6 +153,16 @@
|
||||
${launcher}/bin/text-generation-launcher "$@"
|
||||
'';
|
||||
};
|
||||
|
||||
dockerImage = pkgs.callPackage nix/docker.nix {
|
||||
text-generation-inference = default;
|
||||
};
|
||||
|
||||
dockerImageStreamed = pkgs.callPackage nix/docker.nix {
|
||||
text-generation-inference = default;
|
||||
stream = true;
|
||||
};
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -336,6 +336,7 @@ def launcher(event_loop):
|
||||
use_flash_attention: bool = True,
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
@ -375,6 +376,9 @@ def launcher(event_loop):
|
||||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if kv_cache_dtype is not None:
|
||||
args.append("--kv-cache-dtype")
|
||||
args.append(kv_cache_dtype)
|
||||
if revision is not None:
|
||||
args.append("--revision")
|
||||
args.append(revision)
|
||||
@ -434,6 +438,7 @@ def launcher(event_loop):
|
||||
use_flash_attention: bool = True,
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
@ -456,6 +461,9 @@ def launcher(event_loop):
|
||||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if kv_cache_dtype is not None:
|
||||
args.append("--kv-cache-dtype")
|
||||
args.append(kv_cache_dtype)
|
||||
if revision is not None:
|
||||
args.append("--revision")
|
||||
args.append(revision)
|
||||
@ -589,7 +597,6 @@ def generate_multi():
|
||||
max_new_tokens: int,
|
||||
seed: Optional[int] = None,
|
||||
) -> List[Response]:
|
||||
|
||||
import numpy as np
|
||||
|
||||
arange = np.arange(len(prompts))
|
||||
|
@ -24,13 +24,13 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -2.03125,
|
||||
"logprob": -2.109375,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -1.8671875,
|
||||
"logprob": -1.90625,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
@ -42,48 +42,48 @@
|
||||
},
|
||||
{
|
||||
"id": 2121,
|
||||
"logprob": -1.8125,
|
||||
"logprob": -1.796875,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -0.24121094,
|
||||
"logprob": -0.24511719,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -0.100097656,
|
||||
"logprob": -0.09326172,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 603,
|
||||
"logprob": -0.9453125,
|
||||
"logprob": -0.95703125,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 476,
|
||||
"logprob": -1.703125,
|
||||
"id": 1671,
|
||||
"logprob": -1.5859375,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 4551,
|
||||
"logprob": -2.453125,
|
||||
"id": 577,
|
||||
"logprob": -0.39257812,
|
||||
"special": false,
|
||||
"text": " document"
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 674,
|
||||
"logprob": -0.796875,
|
||||
"id": 3853,
|
||||
"logprob": -1.25,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " form\n\nThe test request form is a document that"
|
||||
"generated_text": " form\n\nThe test request form is used to request"
|
||||
}
|
||||
|
@ -11,12 +11,12 @@
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.640625,
|
||||
"logprob": -9.6484375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.375,
|
||||
"logprob": -10.3671875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
@ -24,19 +24,19 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -0.2824707,
|
||||
"logprob": -0.28271484,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -0.19030762,
|
||||
"logprob": -0.18493652,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 16819,
|
||||
"logprob": -1.4892578,
|
||||
"logprob": -1.4804688,
|
||||
"special": false,
|
||||
"text": " detection"
|
||||
},
|
||||
@ -46,44 +46,44 @@
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -2.0195312,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 8566,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " presence"
|
||||
},
|
||||
{
|
||||
"id": 689,
|
||||
"logprob": -0.16491699,
|
||||
"special": false,
|
||||
"text": " or"
|
||||
},
|
||||
{
|
||||
"id": 14862,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " absence"
|
||||
},
|
||||
{
|
||||
"id": 576,
|
||||
"logprob": -0.9946289,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 671,
|
||||
"logprob": -0.5263672,
|
||||
"logprob": -2.1738281,
|
||||
"special": false,
|
||||
"text": " an"
|
||||
},
|
||||
{
|
||||
"id": 24646,
|
||||
"logprob": -3.0449219,
|
||||
"special": false,
|
||||
"text": " RNA"
|
||||
},
|
||||
{
|
||||
"id": 12369,
|
||||
"logprob": -0.19299316,
|
||||
"special": false,
|
||||
"text": " virus"
|
||||
},
|
||||
{
|
||||
"id": 575,
|
||||
"logprob": -0.10632324,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 6022,
|
||||
"logprob": -0.98095703,
|
||||
"special": false,
|
||||
"text": " patients"
|
||||
},
|
||||
{
|
||||
"id": 1064,
|
||||
"logprob": -1.3095703,
|
||||
"special": false,
|
||||
"text": " who"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request for the detection of the presence or absence of an"
|
||||
"generated_text": "Test request for the detection of an RNA virus in patients who"
|
||||
}
|
||||
|
@ -0,0 +1,104 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.079956055,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.2763672,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37548828,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4628906,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02885437,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.2565918,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0063438416,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3056641,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.6035156,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 3,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -22.96875,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -10.71875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -2.6992188,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -4.8398438,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 720,
|
||||
"logprob": -0.4411621,
|
||||
"special": false,
|
||||
"text": " \n"
|
||||
},
|
||||
{
|
||||
"id": 220,
|
||||
"logprob": -0.35864258,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 128001,
|
||||
"logprob": 0.0,
|
||||
"special": true,
|
||||
"text": "<|end_of_text|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "What is deep learning? \n "
|
||||
}
|
@ -0,0 +1,418 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
}
|
||||
]
|
@ -0,0 +1,89 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -11.0078125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -13.59375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.68847656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28771,
|
||||
"logprob": -1.9394531,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -2.8808594,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.37280273,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.26098633,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0017137527,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1064,
|
||||
"logprob": -2.2695312,
|
||||
"special": false,
|
||||
"text": "##"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -1.9238281,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.48828125,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -11.0078125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -13.59375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.34838867,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13940,
|
||||
"logprob": -0.38916016,
|
||||
"special": false,
|
||||
"text": "``"
|
||||
},
|
||||
{
|
||||
"id": 28832,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "`"
|
||||
},
|
||||
{
|
||||
"id": 3371,
|
||||
"logprob": -1.2529297,
|
||||
"special": false,
|
||||
"text": "json"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28751,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "{"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 2287,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 345,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 3134,
|
||||
"logprob": -0.640625,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request\n```json\n{\n \"request"
|
||||
}
|
@ -0,0 +1,358 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -11.0078125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -13.59375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.68847656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28771,
|
||||
"logprob": -1.9394531,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -2.8828125,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.37329102,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.2602539,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0017185211,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1064,
|
||||
"logprob": -2.2753906,
|
||||
"special": false,
|
||||
"text": "##"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -1.9316406,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.48217773,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -11.0078125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -13.59375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.68847656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28771,
|
||||
"logprob": -1.9394531,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -2.8828125,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.37329102,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.2602539,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0017185211,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1064,
|
||||
"logprob": -2.2753906,
|
||||
"special": false,
|
||||
"text": "##"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -1.9316406,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.48217773,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -11.0078125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -13.59375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.68847656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28771,
|
||||
"logprob": -1.9394531,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -2.8828125,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.37329102,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.2602539,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0017185211,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1064,
|
||||
"logprob": -2.2753906,
|
||||
"special": false,
|
||||
"text": "##"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -1.9316406,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.48217773,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -11.0078125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -13.59375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.68847656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28771,
|
||||
"logprob": -1.9394531,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -2.8828125,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.37329102,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.2602539,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0017185211,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1064,
|
||||
"logprob": -2.2753906,
|
||||
"special": false,
|
||||
"text": "##"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -1.9316406,
|
||||
"special": false,
|
||||
"text": " Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -0.48217773,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||
}
|
||||
]
|
@ -0,0 +1,109 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1724,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -0.7133789,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 16030,
|
||||
"logprob": -13.9296875,
|
||||
"text": "gradient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.048919678,
|
||||
"text": "descent"
|
||||
},
|
||||
{
|
||||
"id": 29973,
|
||||
"logprob": -3.0078125,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.8105469,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.84521484,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25584,
|
||||
"logprob": -0.017028809,
|
||||
"special": false,
|
||||
"text": "Grad"
|
||||
},
|
||||
{
|
||||
"id": 993,
|
||||
"logprob": -0.0027313232,
|
||||
"special": false,
|
||||
"text": "ient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.023254395,
|
||||
"special": false,
|
||||
"text": " descent"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -2.0623207e-05,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.5361328,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 937,
|
||||
"logprob": -0.17578125,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2098,
|
||||
"logprob": -0.00011539459,
|
||||
"special": false,
|
||||
"text": "order"
|
||||
},
|
||||
{
|
||||
"id": 13883,
|
||||
"logprob": -0.47436523,
|
||||
"special": false,
|
||||
"text": " optimization"
|
||||
},
|
||||
{
|
||||
"id": 5687,
|
||||
"logprob": -0.00027680397,
|
||||
"special": false,
|
||||
"text": " algorithm"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 16030,
|
||||
"logprob": null,
|
||||
"text": "gradient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -6.4960938,
|
||||
"text": "descent"
|
||||
},
|
||||
{
|
||||
"id": 29973,
|
||||
"logprob": -5.1484375,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -4.0351562,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -5.2265625,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 10994,
|
||||
"logprob": -1.1542969,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 29991,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "!"
|
||||
},
|
||||
{
|
||||
"id": 739,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " It"
|
||||
},
|
||||
{
|
||||
"id": 2444,
|
||||
"logprob": -0.42260742,
|
||||
"special": false,
|
||||
"text": " seems"
|
||||
},
|
||||
{
|
||||
"id": 366,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " you"
|
||||
},
|
||||
{
|
||||
"id": 29915,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "'"
|
||||
},
|
||||
{
|
||||
"id": 276,
|
||||
"logprob": -0.9838867,
|
||||
"special": false,
|
||||
"text": "re"
|
||||
},
|
||||
{
|
||||
"id": 3211,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " address"
|
||||
},
|
||||
{
|
||||
"id": 292,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "ing"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.15124512,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
|
||||
}
|
@ -0,0 +1,438 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1724,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -0.7133789,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 16030,
|
||||
"logprob": -13.9296875,
|
||||
"text": "gradient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.048919678,
|
||||
"text": "descent"
|
||||
},
|
||||
{
|
||||
"id": 29973,
|
||||
"logprob": -3.0078125,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.8105469,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.84521484,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25584,
|
||||
"logprob": -0.017028809,
|
||||
"special": false,
|
||||
"text": "Grad"
|
||||
},
|
||||
{
|
||||
"id": 993,
|
||||
"logprob": -0.0028476715,
|
||||
"special": false,
|
||||
"text": "ient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.023971558,
|
||||
"special": false,
|
||||
"text": " descent"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -2.0384789e-05,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.5229492,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 937,
|
||||
"logprob": -0.17602539,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2098,
|
||||
"logprob": -0.000116467476,
|
||||
"special": false,
|
||||
"text": "order"
|
||||
},
|
||||
{
|
||||
"id": 13883,
|
||||
"logprob": -0.47436523,
|
||||
"special": false,
|
||||
"text": " optimization"
|
||||
},
|
||||
{
|
||||
"id": 5687,
|
||||
"logprob": -0.00027871132,
|
||||
"special": false,
|
||||
"text": " algorithm"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1724,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -0.7128906,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 16030,
|
||||
"logprob": -13.9375,
|
||||
"text": "gradient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.05053711,
|
||||
"text": "descent"
|
||||
},
|
||||
{
|
||||
"id": 29973,
|
||||
"logprob": -3.0058594,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.8242188,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.84521484,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25584,
|
||||
"logprob": -0.018859863,
|
||||
"special": false,
|
||||
"text": "Grad"
|
||||
},
|
||||
{
|
||||
"id": 993,
|
||||
"logprob": -0.002822876,
|
||||
"special": false,
|
||||
"text": "ient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.023254395,
|
||||
"special": false,
|
||||
"text": " descent"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -2.0384789e-05,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.5229492,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 937,
|
||||
"logprob": -0.17126465,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2098,
|
||||
"logprob": -0.0001155138,
|
||||
"special": false,
|
||||
"text": "order"
|
||||
},
|
||||
{
|
||||
"id": 13883,
|
||||
"logprob": -0.47436523,
|
||||
"special": false,
|
||||
"text": " optimization"
|
||||
},
|
||||
{
|
||||
"id": 5687,
|
||||
"logprob": -0.00027036667,
|
||||
"special": false,
|
||||
"text": " algorithm"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1724,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -0.71484375,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 16030,
|
||||
"logprob": -13.9375,
|
||||
"text": "gradient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.049346924,
|
||||
"text": "descent"
|
||||
},
|
||||
{
|
||||
"id": 29973,
|
||||
"logprob": -3.0078125,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.8242188,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.86328125,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25584,
|
||||
"logprob": -0.017196655,
|
||||
"special": false,
|
||||
"text": "Grad"
|
||||
},
|
||||
{
|
||||
"id": 993,
|
||||
"logprob": -0.0028438568,
|
||||
"special": false,
|
||||
"text": "ient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.023254395,
|
||||
"special": false,
|
||||
"text": " descent"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -2.026558e-05,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.5229492,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 937,
|
||||
"logprob": -0.17602539,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2098,
|
||||
"logprob": -0.00011622906,
|
||||
"special": false,
|
||||
"text": "order"
|
||||
},
|
||||
{
|
||||
"id": 13883,
|
||||
"logprob": -0.48608398,
|
||||
"special": false,
|
||||
"text": " optimization"
|
||||
},
|
||||
{
|
||||
"id": 5687,
|
||||
"logprob": -0.00027894974,
|
||||
"special": false,
|
||||
"text": " algorithm"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1724,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -0.7192383,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 16030,
|
||||
"logprob": -13.9375,
|
||||
"text": "gradient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.050445557,
|
||||
"text": "descent"
|
||||
},
|
||||
{
|
||||
"id": 29973,
|
||||
"logprob": -3.0078125,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.8242188,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8276367,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25584,
|
||||
"logprob": -0.01727295,
|
||||
"special": false,
|
||||
"text": "Grad"
|
||||
},
|
||||
{
|
||||
"id": 993,
|
||||
"logprob": -0.0027542114,
|
||||
"special": false,
|
||||
"text": "ient"
|
||||
},
|
||||
{
|
||||
"id": 26815,
|
||||
"logprob": -0.023254395,
|
||||
"special": false,
|
||||
"text": " descent"
|
||||
},
|
||||
{
|
||||
"id": 338,
|
||||
"logprob": -2.0384789e-05,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.5229492,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 937,
|
||||
"logprob": -0.17126465,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2098,
|
||||
"logprob": -0.00011301041,
|
||||
"special": false,
|
||||
"text": "order"
|
||||
},
|
||||
{
|
||||
"id": 13883,
|
||||
"logprob": -0.48608398,
|
||||
"special": false,
|
||||
"text": " optimization"
|
||||
},
|
||||
{
|
||||
"id": 5687,
|
||||
"logprob": -0.00027894974,
|
||||
"special": false,
|
||||
"text": " algorithm"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
||||
}
|
||||
]
|
@ -0,0 +1,106 @@
|
||||
[
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a bustling city, a chicken named Cluck",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727773835,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 50,
|
||||
"total_tokens": 60
|
||||
}
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a world where even chickens could dream big,",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727773835,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 50,
|
||||
"total_tokens": 60
|
||||
}
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a world where even chickens could dream big,",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727773835,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 50,
|
||||
"total_tokens": 60
|
||||
}
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a world where even chickens could dream big,",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727773835,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 50,
|
||||
"total_tokens": 60
|
||||
}
|
||||
}
|
||||
]
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a bustling city, a chicken named Cluck",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727556016,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.3.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 50,
|
||||
"total_tokens": 60
|
||||
}
|
||||
}
|
77
integration-tests/models/test_flash_llama_fp8_kv_cache.py
Normal file
77
integration-tests/models/test_flash_llama_fp8_kv_cache.py
Normal file
@ -0,0 +1,77 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_fp8_kv_cache_handle(launcher):
|
||||
with launcher(
|
||||
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
|
||||
await flash_llama_fp8_kv_cache_handle.health(300)
|
||||
return flash_llama_fp8_kv_cache_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
|
||||
response = await flash_llama_fp8_kv_cache.generate(
|
||||
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert (
|
||||
response.generated_text
|
||||
== " Deep learning is a subset of machine learning that is"
|
||||
)
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_fp8_kv_cache_all_params(
|
||||
flash_llama_fp8_kv_cache, response_snapshot
|
||||
):
|
||||
response = await flash_llama_fp8_kv_cache.generate(
|
||||
"What is deep learning?",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_fp8_kv_cache_load(
|
||||
flash_llama_fp8_kv_cache, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert (
|
||||
responses[0].generated_text
|
||||
== " Deep learning is a subset of machine learning that is"
|
||||
)
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"Different messages : {[r.generated_text for r in responses]}"
|
||||
assert responses == response_snapshot
|
60
integration-tests/models/test_flash_mixtral_gptq.py
Normal file
60
integration-tests/models/test_flash_mixtral_gptq.py
Normal file
@ -0,0 +1,60 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_mixtral_gptq_handle(launcher):
|
||||
with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
|
||||
await flash_mixtral_gptq_handle.health(300)
|
||||
return flash_mixtral_gptq_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
|
||||
response = await flash_mixtral_gptq.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
|
||||
response = await flash_mixtral_gptq.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_mixtral_gptq_load(
|
||||
flash_mixtral_gptq, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"{[r.generated_text for r in responses]}"
|
||||
|
||||
assert responses == response_snapshot
|
75
integration-tests/models/test_flash_phi35_moe.py
Normal file
75
integration-tests/models/test_flash_phi35_moe.py
Normal file
@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_phi35_moe_handle(launcher):
|
||||
with launcher(
|
||||
"microsoft/Phi-3.5-MoE-instruct",
|
||||
num_shard=4,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_phi35_moe(flash_phi35_moe_handle):
|
||||
await flash_phi35_moe_handle.health(300)
|
||||
return flash_phi35_moe_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
|
||||
response = await flash_phi35_moe.generate(
|
||||
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert (
|
||||
response.generated_text
|
||||
== "Gradient descent is a first-order optimization algorithm"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
|
||||
response = await flash_phi35_moe.generate(
|
||||
"What is gradient descent?\n\n",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert (
|
||||
response.generated_text
|
||||
== "What is gradient descent?\n\nHello! It seems you're addressing a"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert responses[0].details.generated_tokens == 10
|
||||
assert (
|
||||
responses[0].generated_text
|
||||
== "Gradient descent is a first-order optimization algorithm"
|
||||
)
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"{[r.generated_text for r in responses]}"
|
||||
|
||||
assert responses == response_snapshot
|
105
integration-tests/models/test_mllama.py
Normal file
105
integration-tests/models/test_mllama.py
Normal file
@ -0,0 +1,105 @@
|
||||
import pytest
|
||||
import base64
|
||||
import asyncio
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mllama_handle(launcher):
|
||||
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def mllama(mllama_handle):
|
||||
await mllama_handle.health(300)
|
||||
return mllama_handle.client
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
def get_cow_beach():
|
||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mllama_simpl(mllama, response_snapshot):
|
||||
# chicken = get_chicken()
|
||||
response = await mllama.chat(
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Can you tell me a very short story based on the image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert response.usage == {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 50,
|
||||
"total_tokens": 60,
|
||||
}
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "In a bustling city, a chicken named Cluck"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
async def test_mllama_load(mllama, generate_load, response_snapshot):
|
||||
futures = [
|
||||
mllama.chat(
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Can you tell me a very short story based on the image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
responses = await asyncio.gather(*futures)
|
||||
|
||||
generated_texts = [response.choices[0].message.content for response in responses]
|
||||
|
||||
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
)
|
||||
|
||||
assert responses == response_snapshot
|
@ -4,7 +4,9 @@ import pytest
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_grammar_tools_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
num_shard=2,
|
||||
disable_grammar_support=False,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
@ -208,7 +210,7 @@ async def test_flash_llama_grammar_tools_stream(
|
||||
async for response in responses:
|
||||
count += 1
|
||||
|
||||
assert count == 48
|
||||
assert count == 28
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -12,11 +12,13 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
|
||||
hf-hub = "0.3.2"
|
||||
nix = { version = "0.28.0", features = ["signal"] }
|
||||
once_cell = "1.19.0"
|
||||
pyo3 = { workspace = true }
|
||||
serde = { version = "1.0.188", features = ["derive"] }
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.59"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
regex = "1.11.0"
|
||||
|
||||
[dev-dependencies]
|
||||
float_eq = "1.0.1"
|
||||
|
21
launcher/src/gpu.rs
Normal file
21
launcher/src/gpu.rs
Normal file
@ -0,0 +1,21 @@
|
||||
pub fn get_cuda_capability() -> Option<(usize, usize)> {
|
||||
use pyo3::prelude::*;
|
||||
|
||||
let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
|
||||
let torch = py.import_bound("torch.cuda")?;
|
||||
let get_device_capability = torch.getattr("get_device_capability")?;
|
||||
get_device_capability.call0()?.extract()
|
||||
};
|
||||
|
||||
match pyo3::Python::with_gil(py_get_capability) {
|
||||
Ok((major, minor)) if major < 0 || minor < 0 => {
|
||||
tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
|
||||
None
|
||||
}
|
||||
Ok((major, minor)) => Some((major as usize, minor as usize)),
|
||||
Err(err) => {
|
||||
tracing::warn!("Cannot determine GPU compute capability: {}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ use hf_hub::{
|
||||
};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
@ -26,6 +27,7 @@ use thiserror::Error;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||
|
||||
mod env_runtime;
|
||||
mod gpu;
|
||||
|
||||
fn get_config(
|
||||
model_id: &str,
|
||||
@ -65,6 +67,7 @@ fn get_config(
|
||||
}
|
||||
|
||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||
let compute_capability = gpu::get_cuda_capability();
|
||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||
if let Some(config) = config {
|
||||
@ -77,6 +80,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
|
||||
"paged"
|
||||
} else {
|
||||
"flashdecoding"
|
||||
};
|
||||
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||
@ -89,10 +99,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
// flashinfer ?
|
||||
if attention.is_none() {
|
||||
tracing::info!(
|
||||
"Forcing flash decoding because model {} requires it",
|
||||
"Forcing attention to '{fallback_attention}' because model {} requires it",
|
||||
config.model_type.as_ref().unwrap()
|
||||
);
|
||||
attention = Some("flashdecoding".to_string());
|
||||
attention = Some(fallback_attention.to_string());
|
||||
}
|
||||
if fallback_attention == "paged" && prefix_caching.is_none() {
|
||||
tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
}
|
||||
Some("t5") => {}
|
||||
@ -101,8 +115,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
}
|
||||
_ => {
|
||||
if attention.is_none() {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
attention = Some("flashdecoding".to_string());
|
||||
tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
attention = Some(fallback_attention.to_string());
|
||||
}
|
||||
if prefix_caching.is_none() {
|
||||
prefix_caching = Some("0".to_string());
|
||||
@ -110,8 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
}
|
||||
}
|
||||
}
|
||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||
|
||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||
|
||||
(prefix_caching, attention)
|
||||
}
|
||||
|
||||
@ -285,6 +301,22 @@ impl std::fmt::Display for Dtype {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum KVCacheDtype {
|
||||
#[clap(name = "fp8_e5m2")]
|
||||
Fp8e5m2,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for KVCacheDtype {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
KVCacheDtype::Fp8e5m2 => {
|
||||
write!(f, "fp8_e5m2")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum RopeScaling {
|
||||
Linear,
|
||||
@ -386,6 +418,12 @@ struct Args {
|
||||
#[clap(long, env, value_enum)]
|
||||
dtype: Option<Dtype>,
|
||||
|
||||
/// Specify the dtype for the key-value cache. When this option is not provided,
|
||||
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
|
||||
/// the only supported value is `fp8_e5m2` on CUDA.
|
||||
#[clap(long, env, value_enum)]
|
||||
kv_cache_dtype: Option<KVCacheDtype>,
|
||||
|
||||
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
|
||||
/// encouraged when loading a model with custom code to ensure no malicious code has been
|
||||
/// contributed in a newer revision.
|
||||
@ -654,6 +692,7 @@ fn shard_manager(
|
||||
quantize: Option<Quantization>,
|
||||
speculate: Option<usize>,
|
||||
dtype: Option<Dtype>,
|
||||
kv_cache_dtype: Option<KVCacheDtype>,
|
||||
trust_remote_code: bool,
|
||||
uds_path: String,
|
||||
rank: usize,
|
||||
@ -727,6 +766,11 @@ fn shard_manager(
|
||||
shard_args.push(dtype.to_string())
|
||||
}
|
||||
|
||||
if let Some(kv_cache_dtype) = kv_cache_dtype {
|
||||
shard_args.push("--kv-cache-dtype".to_string());
|
||||
shard_args.push(kv_cache_dtype.to_string())
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(revision) = revision {
|
||||
shard_args.push("--revision".to_string());
|
||||
@ -1283,6 +1327,7 @@ fn spawn_shards(
|
||||
let otlp_service_name = args.otlp_service_name.clone();
|
||||
let speculate = args.speculate;
|
||||
let dtype = args.dtype;
|
||||
let kv_cache_dtype = args.kv_cache_dtype;
|
||||
let trust_remote_code = args.trust_remote_code;
|
||||
let master_port = args.master_port;
|
||||
let disable_custom_kernels = args.disable_custom_kernels;
|
||||
@ -1301,6 +1346,7 @@ fn spawn_shards(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
rank,
|
||||
@ -1793,14 +1839,37 @@ fn main() -> Result<(), LauncherError> {
|
||||
if adapter.contains('=') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let adapter = adapter.trim();
|
||||
|
||||
// check if adapter has more than 1 '@'
|
||||
if adapter.matches('@').count() > 1 {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"Invalid LoRA adapter format: {}",
|
||||
adapter
|
||||
)));
|
||||
}
|
||||
|
||||
// capture adapter_id, path, revision in format of adapter_id=path@revision
|
||||
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
|
||||
if let Some(caps) = re.captures(adapter) {
|
||||
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
|
||||
let revision = caps.get(3).map(|m| m.as_str());
|
||||
|
||||
download_convert_model(
|
||||
adapter,
|
||||
None,
|
||||
adapter_id,
|
||||
revision,
|
||||
args.trust_remote_code,
|
||||
args.huggingface_hub_cache.as_deref(),
|
||||
args.weights_cache_override.as_deref(),
|
||||
running.clone(),
|
||||
)?;
|
||||
} else {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"Invalid LoRA adapter format: {}",
|
||||
adapter
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
23
nix/docker.nix
Normal file
23
nix/docker.nix
Normal file
@ -0,0 +1,23 @@
|
||||
{
|
||||
dockerTools,
|
||||
cacert,
|
||||
text-generation-inference,
|
||||
stream ? false,
|
||||
}:
|
||||
|
||||
let
|
||||
build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;
|
||||
in
|
||||
build {
|
||||
name = "tgi-docker";
|
||||
tag = "latest";
|
||||
config = {
|
||||
EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
|
||||
Env = [
|
||||
"HF_HOME=/data"
|
||||
"PORT=80"
|
||||
];
|
||||
|
||||
};
|
||||
contents = [ cacert ];
|
||||
}
|
54
nix/impure-shell.nix
Normal file
54
nix/impure-shell.nix
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
mkShell,
|
||||
openssl,
|
||||
pkg-config,
|
||||
protobuf,
|
||||
python3,
|
||||
pyright,
|
||||
redocly,
|
||||
ruff,
|
||||
rust-bin,
|
||||
server,
|
||||
}:
|
||||
|
||||
mkShell {
|
||||
buildInputs =
|
||||
[
|
||||
openssl.dev
|
||||
pkg-config
|
||||
(rust-bin.stable.latest.default.override {
|
||||
extensions = [
|
||||
"rust-analyzer"
|
||||
"rust-src"
|
||||
];
|
||||
})
|
||||
protobuf
|
||||
pyright
|
||||
redocly
|
||||
ruff
|
||||
]
|
||||
++ (with python3.pkgs; [
|
||||
venvShellHook
|
||||
docker
|
||||
pip
|
||||
ipdb
|
||||
click
|
||||
pytest
|
||||
pytest-asyncio
|
||||
syrupy
|
||||
]);
|
||||
|
||||
inputsFrom = [ server ];
|
||||
|
||||
venvDir = "./.venv";
|
||||
|
||||
postVenvCreation = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
( cd server ; python -m pip install --no-dependencies -e . )
|
||||
( cd clients/python ; python -m pip install --no-dependencies -e . )
|
||||
'';
|
||||
postShellHook = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
export PATH=$PATH:~/.cargo/bin
|
||||
'';
|
||||
}
|
41
nix/overlay.nix
Normal file
41
nix/overlay.nix
Normal file
@ -0,0 +1,41 @@
|
||||
final: prev: {
|
||||
# You can use this overlay to temporarily override packages for
|
||||
# development. For permanent overrides, it's better to do this in
|
||||
# our package flake:
|
||||
#
|
||||
# https://github.com/huggingface/text-generation-inference-nix
|
||||
#
|
||||
# Note that overriding packages that are in the transitive closure
|
||||
# of many other packages (e.g. transformers) will require a large
|
||||
# rebuild.
|
||||
|
||||
pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
|
||||
(
|
||||
python-self: python-super: with python-self; {
|
||||
# Python package override example:
|
||||
# transformers = python-super.transformers.overrideAttrs (
|
||||
# _: _: {
|
||||
# src = final.fetchFromGitHub {
|
||||
# owner = "huggingface";
|
||||
# repo = "transformers";
|
||||
# rev = "2bd4d5897dc73e8b172832070a6f9e567a0df017";
|
||||
# hash = "sha256-JOIpKH9ssDEfI2Tf15e0iPKtThJwQ9GxMvRAnm+M2Pg=";
|
||||
# };
|
||||
# }
|
||||
# );
|
||||
}
|
||||
)
|
||||
];
|
||||
|
||||
# Non-python package override example:
|
||||
#
|
||||
# ripgrep = prev.ripgrep.overrideAttrs (
|
||||
# _: _: {
|
||||
# src = final.fetchFromGitHub {
|
||||
# owner = "BurntSushi";
|
||||
# repo = "ripgrep";
|
||||
# rev = "79cbe89deb1151e703f4d91b19af9cdcc128b765";
|
||||
# hash = "sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=";
|
||||
# };
|
||||
# });
|
||||
}
|
@ -13,6 +13,7 @@
|
||||
flash-attn,
|
||||
flash-attn-layer-norm,
|
||||
flash-attn-rotary,
|
||||
flash-attn-v1,
|
||||
grpc-interceptor,
|
||||
grpcio-reflection,
|
||||
grpcio-status,
|
||||
|
@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
|
||||
] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
||||
pyo3 = { workspace = true }
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -146,6 +146,7 @@ pub enum Config {
|
||||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Mllama,
|
||||
Idefics2(Idefics2),
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
@ -159,6 +160,7 @@ pub enum Config {
|
||||
#[serde(rename = "phi-msft")]
|
||||
PhiMsft,
|
||||
Phi3,
|
||||
PhiMoe,
|
||||
Llama,
|
||||
Baichuan,
|
||||
Paligemma(Paligemma),
|
||||
|
@ -29,7 +29,7 @@ impl ChatTemplate {
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
tracing::debug!("Loading template: {:#?}", template_str);
|
||||
tracing::debug!("Loading template: {}", template_str);
|
||||
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
|
@ -8,9 +8,11 @@ use crate::{
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
};
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
use chat_template::ChatTemplate;
|
||||
use futures::future::try_join_all;
|
||||
use futures::Stream;
|
||||
use minijinja::ErrorKind;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
@ -87,7 +89,14 @@ impl Infer {
|
||||
pub(crate) async fn generate_stream<'a>(
|
||||
&'a self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
) -> Result<
|
||||
(
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
|
||||
),
|
||||
InferError,
|
||||
> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let permit = self
|
||||
.clone()
|
||||
@ -107,9 +116,18 @@ impl Infer {
|
||||
})?;
|
||||
|
||||
let input_length = valid_request.input_length;
|
||||
let generation_stream = self.backend.schedule(valid_request)?;
|
||||
let mut generation_stream = self.backend.schedule(valid_request)?;
|
||||
|
||||
Ok((permit, input_length, generation_stream))
|
||||
// Wrap generation stream to update the backend health if the stream contains an error
|
||||
let final_stream = stream! {
|
||||
while let Some(response) = generation_stream.next().await {
|
||||
yield response.inspect_err(|_err| {
|
||||
self.backend_health.store(false, Ordering::SeqCst);
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok((permit, input_length, final_stream))
|
||||
}
|
||||
|
||||
/// Tokenizer the input
|
||||
@ -278,13 +296,6 @@ impl Infer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for generation responses
|
||||
pub(crate) type GenerateStreamResponse = (
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
|
@ -1,4 +0,0 @@
|
||||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::BackendV2;
|
File diff suppressed because it is too large
Load Diff
@ -9,7 +9,10 @@ mod kserve;
|
||||
pub mod logging;
|
||||
|
||||
pub mod usage_stats;
|
||||
mod vertex;
|
||||
|
||||
use crate::infer::{Infer, InferError};
|
||||
use crate::server::prepare_chat_input;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
@ -54,32 +57,6 @@ impl std::str::FromStr for Attention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
pub(crate) struct GenerateVertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub inputs: String,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
enum VertexInstance {
|
||||
Generate(GenerateVertexInstance),
|
||||
Chat(ChatRequest),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexRequest {
|
||||
#[serde(rename = "instances")]
|
||||
pub instances: Vec<VertexInstance>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct VertexResponse {
|
||||
pub predictions: Vec<String>,
|
||||
}
|
||||
|
||||
/// Hub type
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct HubModelInfo {
|
||||
@ -174,6 +151,7 @@ impl HubProcessorConfig {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
pub(crate) enum GrammarType {
|
||||
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||
@ -230,6 +208,7 @@ pub struct Info {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub(crate) struct GenerateParameters {
|
||||
/// Generate best_of sequences and return the one if the highest token logprobs.
|
||||
#[serde(default)]
|
||||
@ -774,6 +753,7 @@ impl ChatCompletionChunk {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq, Default))]
|
||||
pub(crate) struct ChatRequest {
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||
@ -890,7 +870,82 @@ pub(crate) struct ChatRequest {
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||
infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
Ok((
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k: None,
|
||||
top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
},
|
||||
using_tools,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
struct StreamOptions {
|
||||
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||
#[schema(example = "true")]
|
||||
@ -984,6 +1039,7 @@ pub(crate) struct FunctionDefinition {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub(crate) struct Tool {
|
||||
// The type of the tool. Currently, only 'function' is supported.
|
||||
#[schema(example = "function")]
|
||||
|
@ -8,7 +8,8 @@ use crate::kserve::{
|
||||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
|
||||
use crate::vertex::vertex_compatibility;
|
||||
use crate::ChatTokenizeResponse;
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
@ -20,8 +21,7 @@ use crate::{
|
||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
||||
VertexResponse,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||
};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use crate::{ModelInfo, ModelsInfo};
|
||||
@ -149,63 +149,11 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
|
||||
)]
|
||||
async fn get_chat_tokenize(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
Json(chat): Json<ChatRequest>,
|
||||
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
..
|
||||
} = req;
|
||||
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
let generate_request = GenerateRequest {
|
||||
inputs,
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty: None,
|
||||
frequency_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
max_new_tokens: max_tokens,
|
||||
return_full_text: None,
|
||||
stop: stop.unwrap_or_default(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: false,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: _grammar,
|
||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
};
|
||||
|
||||
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
||||
let input = generate_request.inputs.clone();
|
||||
let encoding = infer.tokenize(generate_request).await?;
|
||||
if let Some(encoding) = encoding {
|
||||
@ -1162,77 +1110,20 @@ async fn chat_completions(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Extension(info): Extension<Info>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
Json(chat): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
let ChatRequest {
|
||||
model,
|
||||
logprobs,
|
||||
max_tokens,
|
||||
messages,
|
||||
presence_penalty,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
stream_options,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
logprobs,
|
||||
..
|
||||
} = req;
|
||||
} = chat.clone();
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.try_into_generate(&infer)?;
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
top_k: None,
|
||||
top_p: req.top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
};
|
||||
let logprobs = logprobs.unwrap_or_default();
|
||||
|
||||
// static values that will be returned in all cases
|
||||
let model_id = info.model_id.clone();
|
||||
@ -1385,186 +1276,6 @@ async fn chat_completions(
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate tokens from Vertex request
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/vertex",
|
||||
request_body = VertexRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
inference_time,
|
||||
time_per_token,
|
||||
seed,
|
||||
)
|
||||
)]
|
||||
async fn vertex_compatibility(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
// check that theres at least one instance
|
||||
if req.instances.is_empty() {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Input validation error".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// Prepare futures for all instances
|
||||
let mut futures = Vec::with_capacity(req.instances.len());
|
||||
|
||||
for instance in req.instances.iter() {
|
||||
let generate_request = match instance {
|
||||
VertexInstance::Generate(instance) => GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
do_sample: true,
|
||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||
details: true,
|
||||
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
VertexInstance::Chat(instance) => {
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
..
|
||||
} = instance.clone();
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to prepare chat input: {}", e),
|
||||
error_type: "Input preparation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k: None,
|
||||
top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
futures.push(async move {
|
||||
generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||
let results = futures::future::join_all(futures).await;
|
||||
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
||||
let predictions = predictions?;
|
||||
|
||||
let response = VertexResponse { predictions };
|
||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||
}
|
||||
|
||||
/// Tokenize inputs
|
||||
#[utoipa::path(
|
||||
post,
|
||||
@ -2226,6 +1937,11 @@ async fn start(
|
||||
metrics::Unit::Count,
|
||||
"Maximum tokens for the current batch"
|
||||
);
|
||||
metrics::describe_gauge!(
|
||||
"tgi_batch_total_tokens",
|
||||
metrics::Unit::Count,
|
||||
"Maximum amount of tokens in total."
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_max_new_tokens",
|
||||
metrics::Unit::Count,
|
||||
@ -2318,7 +2034,8 @@ async fn start(
|
||||
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
use crate::VertexInstance;
|
||||
use crate::vertex::__path_vertex_compatibility;
|
||||
use crate::vertex::{VertexInstance, VertexRequest, VertexResponse};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
@ -2637,7 +2354,7 @@ pub enum WebServerError {
|
||||
|
||||
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||
|
||||
fn prepare_chat_input(
|
||||
pub(crate) fn prepare_chat_input(
|
||||
infer: &Infer,
|
||||
response_format: Option<GrammarType>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
|
@ -567,6 +567,7 @@ fn image_tokens(
|
||||
use HubPreprocessorConfig::*;
|
||||
match config {
|
||||
Idefics => "<image>".to_string(),
|
||||
Mllama => "<|image|>".to_string(),
|
||||
Idefics2(config) => {
|
||||
const FAKE: &str = "<fake_token_around_image>";
|
||||
const IMAGE: &str = "<image>";
|
||||
@ -618,7 +619,7 @@ fn prepare_input(
|
||||
use Config::*;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
|
360
router/src/vertex.rs
Normal file
360
router/src/vertex.rs
Normal file
@ -0,0 +1,360 @@
|
||||
use crate::infer::Infer;
|
||||
use crate::server::{generate_internal, ComputeType};
|
||||
use crate::{
|
||||
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
|
||||
StreamOptions, Tool, ToolChoice,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::instrument;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct GenerateVertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub inputs: String,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct VertexChat {
|
||||
messages: Vec<Message>,
|
||||
// Messages is ignored there.
|
||||
#[serde(default)]
|
||||
parameters: VertexParameters,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct VertexParameters {
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
|
||||
/// decreasing the model's likelihood to repeat the same line verbatim.
|
||||
#[serde(default)]
|
||||
#[schema(example = "1.0")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// UNUSED
|
||||
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
||||
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
||||
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
||||
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
||||
/// result in a ban or exclusive selection of the relevant token.
|
||||
#[serde(default)]
|
||||
pub logit_bias: Option<Vec<f32>>,
|
||||
|
||||
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
||||
/// output token returned in the content of message.
|
||||
#[serde(default)]
|
||||
#[schema(example = "false")]
|
||||
pub logprobs: Option<bool>,
|
||||
|
||||
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
||||
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
||||
#[serde(default)]
|
||||
#[schema(example = "5")]
|
||||
pub top_logprobs: Option<u32>,
|
||||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
#[schema(example = "32")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// UNUSED
|
||||
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
|
||||
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "2")]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
||||
/// increasing the model's likelihood to talk about new topics
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = 0.1)]
|
||||
pub presence_penalty: Option<f32>,
|
||||
|
||||
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
|
||||
#[serde(default = "bool::default")]
|
||||
pub stream: bool,
|
||||
|
||||
#[schema(nullable = true, example = 42)]
|
||||
pub seed: Option<u64>,
|
||||
|
||||
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
|
||||
/// lower values like 0.2 will make it more focused and deterministic.
|
||||
///
|
||||
/// We generally recommend altering this or `top_p` but not both.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = 1.0)]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
|
||||
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = 0.95)]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
|
||||
/// functions the model may generate JSON inputs for.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// A prompt to be appended before the tools
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||
)]
|
||||
pub tool_prompt: Option<String>,
|
||||
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tool_choice: ToolChoice,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
///
|
||||
/// NOTE: A request can use `response_format` OR `tools` but not both.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub response_format: Option<GrammarType>,
|
||||
|
||||
/// A guideline to be used in the chat_template
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub guideline: Option<String>,
|
||||
|
||||
/// Options for streaming response. Only set this when you set stream: true.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
}
|
||||
|
||||
impl From<VertexChat> for ChatRequest {
|
||||
fn from(val: VertexChat) -> Self {
|
||||
Self {
|
||||
messages: val.messages,
|
||||
frequency_penalty: val.parameters.frequency_penalty,
|
||||
guideline: val.parameters.guideline,
|
||||
logit_bias: val.parameters.logit_bias,
|
||||
logprobs: val.parameters.logprobs,
|
||||
max_tokens: val.parameters.max_tokens,
|
||||
model: val.parameters.model,
|
||||
n: val.parameters.n,
|
||||
presence_penalty: val.parameters.presence_penalty,
|
||||
response_format: val.parameters.response_format,
|
||||
seed: val.parameters.seed,
|
||||
stop: val.parameters.stop,
|
||||
stream_options: val.parameters.stream_options,
|
||||
stream: val.parameters.stream,
|
||||
temperature: val.parameters.temperature,
|
||||
tool_choice: val.parameters.tool_choice,
|
||||
tool_prompt: val.parameters.tool_prompt,
|
||||
tools: val.parameters.tools,
|
||||
top_logprobs: val.parameters.top_logprobs,
|
||||
top_p: val.parameters.top_p,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
#[serde(untagged)]
|
||||
pub(crate) enum VertexInstance {
|
||||
Generate(GenerateVertexInstance),
|
||||
Chat(VertexChat),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct VertexRequest {
|
||||
#[serde(rename = "instances")]
|
||||
pub instances: Vec<VertexInstance>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct VertexResponse {
|
||||
pub predictions: Vec<String>,
|
||||
}
|
||||
|
||||
/// Generate tokens from Vertex request
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/vertex",
|
||||
request_body = VertexRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
inference_time,
|
||||
time_per_token,
|
||||
seed,
|
||||
)
|
||||
)]
|
||||
pub(crate) async fn vertex_compatibility(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
// check that theres at least one instance
|
||||
if req.instances.is_empty() {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Input validation error".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// Prepare futures for all instances
|
||||
let mut futures = Vec::with_capacity(req.instances.len());
|
||||
|
||||
for instance in req.instances.into_iter() {
|
||||
let generate_request = match instance {
|
||||
VertexInstance::Generate(instance) => GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
do_sample: true,
|
||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||
details: true,
|
||||
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
VertexInstance::Chat(instance) => {
|
||||
let chat_request: ChatRequest = instance.into();
|
||||
let (generate_request, _using_tools): (GenerateRequest, bool) =
|
||||
chat_request.try_into_generate(&infer)?;
|
||||
generate_request
|
||||
}
|
||||
};
|
||||
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
futures.push(async move {
|
||||
generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||
let results = futures::future::join_all(futures).await;
|
||||
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
||||
let predictions = predictions?;
|
||||
|
||||
let response = VertexResponse { predictions };
|
||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Message, MessageContent};
|
||||
|
||||
#[test]
|
||||
fn vertex_deserialization() {
|
||||
let string = serde_json::json!({
|
||||
|
||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||
"parameters": {
|
||||
"max_tokens": 128,
|
||||
"top_p": 0.95,
|
||||
"temperature": 0.7
|
||||
}
|
||||
});
|
||||
|
||||
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
||||
|
||||
let string = serde_json::json!({
|
||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||
});
|
||||
|
||||
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
||||
|
||||
let string = serde_json::json!({
|
||||
|
||||
"instances": [
|
||||
{
|
||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||
"parameters": {
|
||||
"max_tokens": 128,
|
||||
"top_p": 0.95,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
});
|
||||
let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
|
||||
assert_eq!(
|
||||
request,
|
||||
VertexRequest {
|
||||
instances: vec![VertexInstance::Chat(VertexChat {
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
||||
name: None,
|
||||
},],
|
||||
parameters: VertexParameters {
|
||||
max_tokens: Some(128),
|
||||
top_p: Some(0.95),
|
||||
temperature: Some(0.7),
|
||||
..Default::default()
|
||||
}
|
||||
})]
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
flash_att_v2_commit_cuda := v2.6.1
|
||||
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
||||
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
|
||||
|
||||
build-flash-attention-v2-cuda:
|
||||
pip install -U packaging wheel
|
||||
@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||
build-flash-attention-v2-rocm:
|
||||
if [ ! -d 'flash-attention-v2' ]; then \
|
||||
pip install -U packaging ninja --no-cache-dir && \
|
||||
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
|
||||
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
|
||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
|
||||
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
|
||||
fi
|
||||
|
@ -1,5 +1,5 @@
|
||||
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
|
||||
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
|
||||
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
|
||||
build-vllm-cuda:
|
||||
if [ ! -d 'vllm' ]; then \
|
||||
pip install -U ninja packaging --no-cache-dir && \
|
||||
@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
|
||||
build-vllm-rocm:
|
||||
if [ ! -d 'vllm' ]; then \
|
||||
pip install -U ninja packaging --no-cache-dir && \
|
||||
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
|
||||
git clone https://github.com/mht-sharma/vllm.git vllm; \
|
||||
fi
|
||||
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||
|
@ -1,5 +1,17 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
import torch
|
||||
|
||||
extra_cuda_cflags = []
|
||||
extra_cflags = []
|
||||
if torch.version.hip:
|
||||
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": extra_cflags,
|
||||
"nvcc": extra_cuda_cflags,
|
||||
}
|
||||
|
||||
setup(
|
||||
name="exllama_kernels",
|
||||
@ -13,6 +25,7 @@ setup(
|
||||
"exllama_kernels/cuda_func/q4_matmul.cu",
|
||||
"exllama_kernels/cuda_func/q4_matrix.cu",
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
|
@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
import torch
|
||||
|
||||
extra_cuda_cflags = ["-lineinfo", "-O3"]
|
||||
|
||||
extra_cflags = []
|
||||
if torch.version.hip:
|
||||
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
|
||||
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": extra_cflags,
|
||||
"nvcc": extra_cuda_cflags,
|
||||
}
|
||||
|
||||
|
2588
server/poetry.lock
generated
2588
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
|
||||
opentelemetry-exporter-otlp = "^1.25.0"
|
||||
opentelemetry-instrumentation-grpc = "^0.46b0"
|
||||
hf-transfer = "^0.1.2"
|
||||
sentencepiece = "^0.1.97"
|
||||
tokenizers = "^0.19.1"
|
||||
sentencepiece = "^0.2"
|
||||
tokenizers = "^0.20"
|
||||
huggingface-hub = "^0.23"
|
||||
transformers = "^4.43"
|
||||
transformers = "^4.45"
|
||||
einops = "^0.6.1"
|
||||
texttable = { version = "^1.6.7", optional = true }
|
||||
datasets = { version = "^2.14.0", optional = true }
|
||||
@ -47,10 +47,10 @@ marlin-kernels = [
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
]
|
||||
moe-kernels = [
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
]
|
||||
rich = "^13.7.1"
|
||||
|
||||
|
@ -1,19 +1,19 @@
|
||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -1,19 +1,19 @@
|
||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -1,19 +1,19 @@
|
||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -30,6 +30,10 @@ class Dtype(str, Enum):
|
||||
bloat16 = "bfloat16"
|
||||
|
||||
|
||||
class KVCacheDtype(str, Enum):
|
||||
fp8_e5m2 = "fp8_e5m2"
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model_id: str,
|
||||
@ -38,6 +42,7 @@ def serve(
|
||||
quantize: Optional[Quantization] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[Dtype] = None,
|
||||
kv_cache_dtype: Optional[KVCacheDtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
uds_path: Path = "/tmp/text-generation-server",
|
||||
logger_level: str = "INFO",
|
||||
@ -97,6 +102,7 @@ def serve(
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
|
||||
if dtype is not None and quantize not in {
|
||||
None,
|
||||
"bitsandbytes",
|
||||
@ -114,6 +120,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
max_input_tokens,
|
||||
|
@ -1,29 +1,47 @@
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
import os
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
from .common import Seqlen
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
from .cuda import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
elif SYSTEM == "rocm":
|
||||
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||
from .rocm import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
elif SYSTEM == "ipex":
|
||||
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||
from .ipex import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||
|
||||
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
||||
from .kv_cache import KVCache
|
||||
|
||||
__all__ = [
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"KVCache",
|
||||
"Seqlen",
|
||||
]
|
||||
|
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
import torch
|
||||
from typing import Optional
|
||||
@ -65,5 +66,7 @@ else:
|
||||
max_k: int
|
||||
|
||||
def clamp(self, max):
|
||||
if SYSTEM == "rocm":
|
||||
return self
|
||||
raise NotImplementedError("Not implemented seqlen for paged")
|
||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
||||
|
@ -287,16 +287,14 @@ elif V2:
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=None,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
@ -338,16 +336,30 @@ else:
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
causal,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
# Prefill in the cache with every kind of attention, unless we
|
||||
# have a configuration that requires flash-attention v1, which
|
||||
# does not support block tables.
|
||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
]
|
||||
|
@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
query_dtype: str = "float16",
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=query_dtype,
|
||||
q_data_type=dtype,
|
||||
page_size=page_size,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
@ -119,7 +121,8 @@ def use_prefill_state(
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
query_dtype: str = "float16",
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
@ -135,7 +138,8 @@ def use_prefill_state(
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=query_dtype,
|
||||
q_data_type=dtype,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
@ -152,11 +156,13 @@ def create_decode_state(
|
||||
):
|
||||
"""Create a decode state."""
|
||||
workspace_buffer = get_workspace(device)
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout="NHD",
|
||||
use_cuda_graph=False,
|
||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
||||
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||
)
|
||||
|
||||
|
||||
@ -175,6 +181,7 @@ def create_decode_state_cuda_graphs(
|
||||
therefore stored as part of the state.
|
||||
"""
|
||||
workspace_buffer = get_workspace(device)
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout="NHD",
|
||||
@ -182,7 +189,8 @@ def create_decode_state_cuda_graphs(
|
||||
paged_kv_indices_buffer=block_tables,
|
||||
paged_kv_indptr_buffer=block_tables_ptr,
|
||||
paged_kv_last_page_len_buffer=last_page_len,
|
||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
||||
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||
)
|
||||
|
||||
|
||||
@ -196,7 +204,8 @@ def use_decode_state(
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
query_dtype: str = "float16",
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer decoding state to the given
|
||||
@ -231,7 +240,9 @@ def use_decode_state(
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
page_size=page_size,
|
||||
q_data_type=query_dtype,
|
||||
data_type=dtype,
|
||||
q_data_type=dtype,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
|
@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
|
||||
def attention(
|
||||
@ -79,3 +80,12 @@ def paged_attention(
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
]
|
||||
|
121
server/text_generation_server/layers/attention/kv_cache.py
Normal file
121
server/text_generation_server/layers/attention/kv_cache.py
Normal file
@ -0,0 +1,121 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import reshape_and_cache
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""
|
||||
Key-value cache for attention layers.
|
||||
"""
|
||||
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_blocks: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Construct the key-value cache for a layer."""
|
||||
|
||||
if (
|
||||
dtype == torch.float8_e5m2
|
||||
and ATTENTION != "flashinfer"
|
||||
and SYSTEM != "cuda"
|
||||
):
|
||||
raise ValueError(
|
||||
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
|
||||
)
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
if SYSTEM == "ipex" and device.type == "xpu":
|
||||
x = 1
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.kv_cache = (
|
||||
torch.zeros(
|
||||
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.zeros(
|
||||
(num_blocks, num_heads, head_size, BLOCK_SIZE),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
"""Get the key cache."""
|
||||
|
||||
return self.kv_cache[0]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""Get the value cache."""
|
||||
|
||||
return self.kv_cache[1]
|
||||
|
||||
def store(
|
||||
self,
|
||||
*,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
"""Store the key and value at the given slots."""
|
||||
|
||||
key_cache = self.kv_cache[0]
|
||||
value_cache = self.kv_cache[1]
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
# TODO: add scale
|
||||
key = key.to(key_cache.dtype)
|
||||
value = value.to(value_cache.dtype)
|
||||
if key_cache.dtype == torch.float8_e5m2:
|
||||
# Torch index_put does not support float8_e5m2 yet, so
|
||||
# put as raw data instead.
|
||||
key_cache = key_cache.view(torch.uint8)
|
||||
value_cache = value_cache.view(torch.uint8)
|
||||
key = key.view(torch.uint8)
|
||||
value = value.view(torch.uint8)
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
reshape_and_cache(key, value, key_cache, value_cache, slots)
|
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
@ -8,13 +9,28 @@ from loguru import logger
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
_PARTITION_SIZE_V1V2 = 512
|
||||
_PARTITION_SIZE_CUSTOM = 256
|
||||
|
||||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||
ENGINE = "triton" if use_triton else "ck"
|
||||
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||
try:
|
||||
from vllm._C import cache_ops
|
||||
if use_rocm_custom_paged_attn:
|
||||
from vllm._custom_C import paged_attention_custom
|
||||
except ImportError as e:
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Custom Paged Attention not available. Complete error: {e}",
|
||||
)
|
||||
use_rocm_custom_paged_attn = False
|
||||
|
||||
try:
|
||||
import vllm._custom_ops as ops
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||
@ -33,9 +49,7 @@ def reshape_and_cache(
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
@ -45,8 +59,9 @@ def paged_attention(
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: Seqlen,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
# Copyright 2023 The vLLM team. All rights
|
||||
@ -65,11 +80,31 @@ def paged_attention(
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
if softcap is not None:
|
||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||
|
||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
|
||||
num_kv_heads = key_cache.shape[1]
|
||||
gqa_ratio = num_heads // num_kv_heads
|
||||
use_custom = (
|
||||
use_rocm_custom_paged_attn
|
||||
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
|
||||
and (head_size == 128 or head_size == 64)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_s <= 32768
|
||||
)
|
||||
|
||||
if not use_custom:
|
||||
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
|
||||
else:
|
||||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
||||
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
input_lengths = input_lengths.input_lengths
|
||||
input_lengths = seqlen.input_lengths
|
||||
|
||||
out = torch.empty_like(query)
|
||||
|
||||
@ -78,9 +113,13 @@ def paged_attention(
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
from vllm._C import ops
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
use_v1 = (
|
||||
max_s <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
and not use_custom
|
||||
)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
@ -112,6 +151,7 @@ def paged_attention(
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
if not use_custom:
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
@ -130,6 +170,25 @@ def paged_attention(
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
paged_attention_custom(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -156,7 +215,6 @@ if ENGINE != "triton":
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
) from e
|
||||
else:
|
||||
|
||||
for idx in range(torch.cuda.device_count()):
|
||||
name = torch.cuda.get_device_name(idx)
|
||||
if "MI210" not in name and "MI250" not in name:
|
||||
@ -173,13 +231,14 @@ if ENGINE == "ck":
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: float = 0.0,
|
||||
):
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
@ -189,46 +248,57 @@ if ENGINE == "ck":
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
window_size_left,
|
||||
0,
|
||||
softcap,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
)[0]
|
||||
|
||||
elif ENGINE == "triton":
|
||||
from .flash_attn_triton import triton_attention
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is only available with CK flash attn")
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
output, _ = triton_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
causal,
|
||||
softmax_scale,
|
||||
)
|
||||
@ -236,3 +306,11 @@ elif ENGINE == "triton":
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
]
|
||||
|
@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = weights.get_tensor(
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
scale = (
|
||||
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
.reshape(-1)
|
||||
.expand(w.shape[0])
|
||||
)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
if scale.numel() > 1:
|
||||
scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
|
||||
).reshape(-1)
|
||||
f"{prefix}.weight_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
scale = scale.reshape(-1).expand(w.shape[0])
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
w = [
|
||||
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
||||
]
|
||||
shapes = [x.shape for x in w]
|
||||
|
||||
# Concat then send to the device
|
||||
w = torch.cat(w, dim=dim).to(weights.device)
|
||||
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
scale = [
|
||||
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
|
||||
for p in prefixes
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
]
|
||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
scale = weights.get_tensor(
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
scale = (
|
||||
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
.reshape(-1)
|
||||
.expand(w.shape[0])
|
||||
)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -182,6 +196,9 @@ class Fp8Weight(Weight):
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.weight_scale is None:
|
||||
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
||||
# This is not checked by the fbgemm kernels, but they require contiguous
|
||||
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||
self.weight_scale = self.weight_scale.contiguous()
|
||||
return get_fp8_linear().from_fp8(
|
||||
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
||||
)
|
||||
@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||
if FBGEMM_DYN_AVAILABLE:
|
||||
# fbgemm needs float32 scales.
|
||||
scale = scale.float()
|
||||
return cls(
|
||||
qweight=weight,
|
||||
scale=scale,
|
||||
@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
|
||||
bias=self.bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
||||
scale = weights.get_tensor(prefix, to_dtype=False)
|
||||
if scale.numel() > 1:
|
||||
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
||||
return scale.reshape(-1).expand(shape[0])
|
||||
|
@ -1,12 +1,21 @@
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from torch.nn import functional as F
|
||||
import os
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
)
|
||||
|
||||
if ROCM_USE_SKINNY_GEMM:
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
raise ImportError(
|
||||
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
|
||||
)
|
||||
|
||||
|
||||
class FastLinear(torch.nn.Module):
|
||||
@ -48,6 +57,14 @@ class FastLinearROCm(torch.nn.Module):
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.cu_count = torch.cuda.get_device_properties(
|
||||
device="cuda"
|
||||
).multi_processor_count
|
||||
self.use_skinny_gemm = (
|
||||
ROCM_USE_SKINNY_GEMM
|
||||
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
@ -61,7 +78,11 @@ class FastLinearROCm(torch.nn.Module):
|
||||
weight = self.weight
|
||||
bias = self.bias
|
||||
|
||||
if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
|
||||
if (
|
||||
self.use_skinny_gemm
|
||||
and inp.dtype == torch.float16
|
||||
and inp.shape[-1] % 8 == 0
|
||||
):
|
||||
batched = False
|
||||
inp_shape = inp.shape
|
||||
|
||||
@ -69,13 +90,16 @@ class FastLinearROCm(torch.nn.Module):
|
||||
inp = inp.view(-1, inp_shape[-1])
|
||||
batched = True
|
||||
|
||||
m, k = weight.shape[0], inp_shape[1]
|
||||
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
|
||||
if m > 8 and n <= 4:
|
||||
out = torch.empty(
|
||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
|
||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
|
||||
)
|
||||
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192:
|
||||
out = torch.empty(
|
||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
|
||||
)
|
||||
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
|
||||
_custom_C.LLMM1(weight, inp, out, 8)
|
||||
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
|
||||
_custom_C.LLMM1(weight, inp, out, 4)
|
||||
else:
|
||||
out = F.linear(inp, weight)
|
||||
|
@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
@ -352,7 +351,7 @@ def repack_gptq_for_marlin(
|
||||
|
||||
scales = permute_scales(scales)
|
||||
|
||||
is_full_k = not (desc_act and sharded_infeatures)
|
||||
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
|
||||
|
||||
return GPTQMarlinWeight(
|
||||
qweight=repacked,
|
||||
|
@ -1,15 +1,186 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
||||
from text_generation_server.layers.moe.gptq_marlin import (
|
||||
GPTQMarlinSparseMoELayer,
|
||||
can_use_marlin_moe_gemm,
|
||||
)
|
||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
UnquantizedWeight,
|
||||
)
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
from .fused_moe_rocm import grouped_topk
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
elif SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||
|
||||
|
||||
# NOTE: we are using a protocol here, because multiple inherance is not nice.
|
||||
# We need `Module`, and `Module` -> some abstract class -> some concrete
|
||||
# class inheritance is whacky.
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MoELayer(Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
hidden_act: str = "silu",
|
||||
): ...
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, *, gating_output: torch.Tensor
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DenseMoELayer(nn.Module):
|
||||
"""
|
||||
Layer for MoE that applies *all* experts to each tokens and then weights
|
||||
their outputs based on the calculated routing. This layer is much slower
|
||||
than `SparseMoELayer` and should only be used when no fused kernels are
|
||||
available (e.g. for unsupported quantizers).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
hidden_act: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
log_once(
|
||||
logger.info,
|
||||
"No fused layers are available for this model type, using (slower) dense MoE layer",
|
||||
)
|
||||
|
||||
assert (n_expert_group is None) == (
|
||||
topk_group is None
|
||||
), "n_expert_group and topk_group must both be None or have some value"
|
||||
|
||||
self.n_expert_group = n_expert_group
|
||||
self.n_experts = n_experts
|
||||
self.renormalize = renormalize
|
||||
self.topk = topk
|
||||
self.topk_group = topk_group
|
||||
|
||||
if "gelu" in hidden_act:
|
||||
self.act = lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh"
|
||||
if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none"
|
||||
),
|
||||
)
|
||||
elif "silu" in hidden_act:
|
||||
self.act = torch.nn.functional.silu
|
||||
else:
|
||||
self.act = ACT2FN[hidden_act]
|
||||
|
||||
self.gate_proj = [
|
||||
TensorParallelColumnLinear.load(
|
||||
None,
|
||||
prefix=f"{prefix}.{i}.{gate_proj_name}",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
for i in range(self.n_experts)
|
||||
]
|
||||
self.up_proj = [
|
||||
TensorParallelColumnLinear.load(
|
||||
None,
|
||||
prefix=f"{prefix}.{i}.{up_proj_name}",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
for i in range(self.n_experts)
|
||||
]
|
||||
self.down_proj = [
|
||||
TensorParallelRowLinear.load(
|
||||
None,
|
||||
prefix=f"{prefix}.{i}.{down_proj_name}",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
for i in range(self.n_experts)
|
||||
]
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gating_output: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
if self.n_expert_group is not None and self.topk_group is not None:
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
x,
|
||||
gating_output,
|
||||
self.topk,
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
x, gating_output, self.topk, self.renormalize
|
||||
)
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
weights = torch.zeros(
|
||||
topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
|
||||
)
|
||||
|
||||
weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
|
||||
|
||||
out = torch.zeros_like(x)
|
||||
for i in range(self.n_experts):
|
||||
h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
|
||||
h = self.down_proj[i](h, reduce=False)
|
||||
out += h * weights[:, i].view(-1, 1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SparseMoELayer(nn.Module):
|
||||
"""
|
||||
@ -39,14 +210,18 @@ class SparseMoELayer(nn.Module):
|
||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
||||
cls = UnquantizedSparseMoELayer
|
||||
# Once we wire up GPTQ-Marlin MoE:
|
||||
# elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
||||
# cls = GPTQMarlinSparseMoELayer
|
||||
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
||||
cls = GPTQMarlinSparseMoELayer
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
|
||||
)
|
||||
|
||||
log_once(
|
||||
logger.info,
|
||||
"Using MoE layer wih fused gemm",
|
||||
)
|
||||
|
||||
self.moe = cls(
|
||||
n_expert_group=n_expert_group,
|
||||
n_experts=n_experts,
|
||||
@ -71,6 +246,12 @@ class SparseMoELayer(nn.Module):
|
||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||
)
|
||||
or isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||
# Once we wire up GPTQ-Marlin MoE:
|
||||
# or isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||
or (
|
||||
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||
and can_use_marlin_moe_gemm(
|
||||
quant_method=weights.loader.quant_method,
|
||||
quantize=weights.loader.quantize,
|
||||
sym=weights.loader.sym,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
52
server/text_generation_server/layers/moe/fused_moe_rocm.py
Normal file
52
server/text_generation_server/layers/moe/fused_moe_rocm.py
Normal file
@ -0,0 +1,52 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
||||
# TODO: Remove the functions once moe_kernel are built for ROCM
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
num_token = scores.shape[0]
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
215
server/text_generation_server/layers/moe/gptq_marlin.py
Normal file
215
server/text_generation_server/layers/moe/gptq_marlin.py
Normal file
@ -0,0 +1,215 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.layers.marlin.gptq import (
|
||||
GPTQMarlinWeight,
|
||||
GPTQMarlinWeightsLoader,
|
||||
)
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
from moe_kernels.fused_marlin_moe import fused_marlin_moe
|
||||
else:
|
||||
fused_marlin_moe = None
|
||||
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
has_sm_8_0 = major >= 8
|
||||
except Exception:
|
||||
has_sm_8_0 = False
|
||||
|
||||
|
||||
def can_use_marlin_moe_gemm(
|
||||
*,
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
):
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and fused_marlin_moe is not None
|
||||
and has_sm_8_0
|
||||
and quantize == "gptq"
|
||||
and quant_method == "gptq"
|
||||
and sym
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQMarlinMoEWeight:
|
||||
qweight: torch.Tensor
|
||||
qzeros: torch.Tensor
|
||||
scales: torch.Tensor
|
||||
g_idx: torch.Tensor
|
||||
perm: torch.Tensor
|
||||
is_full_k: bool
|
||||
|
||||
|
||||
class GPTQMarlinSparseMoELayer(nn.Module):
|
||||
"""
|
||||
MoE layer that uses a fused GPTQ-Marlin kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not (
|
||||
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
|
||||
)
|
||||
|
||||
assert (n_expert_group is None) == (
|
||||
topk_group is None
|
||||
), "n_expert_group and topk_group must both be None or have some value"
|
||||
|
||||
self.n_expert_group = n_expert_group
|
||||
self.topk = topk
|
||||
self.topk_group = topk_group
|
||||
self.renormalize = renormalize
|
||||
|
||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
names=[gate_proj_name, up_proj_name],
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.down_proj = _load_expert_weights_row(
|
||||
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
|
||||
)
|
||||
|
||||
self.bits = weights.loader.bits
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj.qweight,
|
||||
w2=self.down_proj.qweight,
|
||||
g_idx1=self.gate_up_proj.g_idx,
|
||||
g_idx2=self.down_proj.g_idx,
|
||||
perm1=self.gate_up_proj.perm,
|
||||
perm2=self.down_proj.perm,
|
||||
w1_scale=self.gate_up_proj.scales,
|
||||
w2_scale=self.down_proj.scales,
|
||||
is_full_k1=self.gate_up_proj.is_full_k,
|
||||
is_full_k2=self.down_proj.is_full_k,
|
||||
gating_output=gating_output,
|
||||
topk=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
num_bits=self.bits,
|
||||
)
|
||||
|
||||
|
||||
def _load_expert_multi_weights_col(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
names: List[str],
|
||||
weights: Weights,
|
||||
) -> GPTQMarlinMoEWeight:
|
||||
moe_weight = None
|
||||
for i in range(n_experts):
|
||||
weight = weights.get_multi_weights_col(
|
||||
[f"{prefix}.{i}.{name}" for name in names], 0
|
||||
)
|
||||
assert isinstance(weight, GPTQMarlinWeight)
|
||||
moe_weight = _pack_weight(
|
||||
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
|
||||
)
|
||||
assert moe_weight is not None
|
||||
return moe_weight
|
||||
|
||||
|
||||
def _load_expert_weights_row(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
) -> GPTQMarlinMoEWeight:
|
||||
moe_weight = None
|
||||
for i in range(n_experts):
|
||||
weight = weights.get_weights_row(
|
||||
f"{prefix}.{i}.{name}",
|
||||
)
|
||||
assert isinstance(weight, GPTQMarlinWeight)
|
||||
moe_weight = _pack_weight(
|
||||
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
|
||||
)
|
||||
assert moe_weight is not None
|
||||
return moe_weight
|
||||
|
||||
|
||||
def _pack_weight(
|
||||
*,
|
||||
n_experts: int,
|
||||
expert: int,
|
||||
moe_weight: Optional[GPTQMarlinMoEWeight],
|
||||
weight: GPTQMarlinWeight,
|
||||
) -> GPTQMarlinMoEWeight:
|
||||
if moe_weight is None:
|
||||
qweight = torch.empty(
|
||||
(n_experts,) + weight.qweight.shape,
|
||||
dtype=weight.qweight.dtype,
|
||||
device=weight.qweight.device,
|
||||
)
|
||||
qzeros = torch.empty(
|
||||
(n_experts,) + weight.qzeros.shape,
|
||||
dtype=weight.qzeros.dtype,
|
||||
device=weight.qzeros.device,
|
||||
)
|
||||
scales = torch.empty(
|
||||
(n_experts,) + weight.scales.shape,
|
||||
dtype=weight.scales.dtype,
|
||||
device=weight.scales.device,
|
||||
)
|
||||
g_idx = torch.empty(
|
||||
(n_experts,) + weight.g_idx.shape,
|
||||
dtype=weight.g_idx.dtype,
|
||||
device=weight.g_idx.device,
|
||||
)
|
||||
perm = torch.empty(
|
||||
(n_experts,) + weight.perm.shape,
|
||||
dtype=weight.perm.dtype,
|
||||
device=weight.perm.device,
|
||||
)
|
||||
|
||||
moe_weight = GPTQMarlinMoEWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
perm=perm,
|
||||
is_full_k=weight.is_full_k,
|
||||
)
|
||||
|
||||
moe_weight.qweight[expert] = weight.qweight
|
||||
moe_weight.qzeros[expert] = weight.qzeros
|
||||
moe_weight.scales[expert] = weight.scales
|
||||
moe_weight.g_idx[expert] = weight.g_idx
|
||||
moe_weight.perm[expert] = weight.perm
|
||||
|
||||
return moe_weight
|
@ -6,7 +6,9 @@ import torch.nn as nn
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
if SYSTEM == "rocm":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
elif SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_moe
|
||||
|
||||
|
||||
@ -52,6 +54,17 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
if SYSTEM == "rocm":
|
||||
return fused_moe(
|
||||
x,
|
||||
self.gate_up_proj,
|
||||
self.down_proj,
|
||||
gating_output,
|
||||
self.topk,
|
||||
renormalize=self.renormalize,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj,
|
||||
|
@ -166,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
1 + math.log(scale) / math.log(original_max_position_embeddings)
|
||||
)
|
||||
|
||||
# if short_mscale and long_mscale are provided we need to scale the freqs
|
||||
# using the Phi3LongRoPEScaledRotaryEmbedding
|
||||
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
|
||||
short_mscale = rope_scaling["short_mscale"]
|
||||
long_mscale = rope_scaling["long_mscale"]
|
||||
return Phi3LongRoPEScaledRotaryEmbedding(
|
||||
short_inv_freq=short_inv_freq,
|
||||
long_inv_freq=long_inv_freq,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
short_mscale=short_mscale,
|
||||
long_mscale=long_mscale,
|
||||
original_max_position_embeddings=original_max_position_embeddings,
|
||||
)
|
||||
|
||||
return SuRotaryEmbedding(
|
||||
short_inv_freq=short_inv_freq,
|
||||
long_inv_freq=long_inv_freq,
|
||||
@ -287,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
@ -308,6 +323,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
||||
|
||||
|
||||
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
short_inv_freq: torch.Tensor,
|
||||
long_inv_freq: torch.Tensor,
|
||||
max_position_embeddings: int,
|
||||
short_mscale: float,
|
||||
long_mscale: float,
|
||||
original_max_position_embeddings: int,
|
||||
):
|
||||
super(PositionRotaryEmbedding, self).__init__()
|
||||
self.short_inv_freq = short_inv_freq
|
||||
self.long_inv_freq = long_inv_freq
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.short_mscale = short_mscale
|
||||
self.long_mscale = long_mscale
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
|
||||
# cache
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
self.dynamic_args = None
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||
|
||||
short_freqs = torch.outer(
|
||||
t[: self.original_max_position_embeddings],
|
||||
self.short_inv_freq.to(device=t.device),
|
||||
)
|
||||
|
||||
long_freqs = torch.outer(
|
||||
t[self.original_max_position_embeddings :],
|
||||
self.long_inv_freq.to(device=t.device),
|
||||
)
|
||||
|
||||
short_freqs = short_freqs * self.short_mscale
|
||||
long_freqs = long_freqs * self.long_mscale
|
||||
|
||||
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
|
||||
freqs[: self.original_max_position_embeddings] = short_freqs
|
||||
freqs[self.original_max_position_embeddings :] = long_freqs
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
@ -467,7 +539,6 @@ def apply_llama3_scaling(
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / scaling_factor)
|
||||
else:
|
||||
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
|
||||
high_freq_factor - low_freq_factor
|
||||
|
@ -32,6 +32,9 @@ from text_generation_server.models.custom_modeling.phi_modeling import (
|
||||
PhiConfig,
|
||||
PhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
|
||||
PhiMoEConfig,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.t5_modeling import (
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
@ -73,6 +76,7 @@ FLASH_ATTENTION = True
|
||||
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
||||
FlashDeepseekV2ForCausalLM,
|
||||
DeepseekV2Config,
|
||||
@ -109,7 +113,11 @@ try:
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
|
||||
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.mllama import (
|
||||
MllamaForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
@ -146,7 +154,7 @@ except ImportError as e:
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashCausalLM)
|
||||
__all__.append(IDEFICSSharded)
|
||||
__all__.append(IdeficsCausalLM)
|
||||
|
||||
MAMBA_AVAILABLE = True
|
||||
try:
|
||||
@ -237,6 +245,11 @@ class ModelType(enum.Enum):
|
||||
"name": "Phi",
|
||||
"url": "https://huggingface.co/microsoft/phi-1_5",
|
||||
}
|
||||
PHI_MOE = {
|
||||
"type": "phimoe",
|
||||
"name": "PhiMoe",
|
||||
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
|
||||
}
|
||||
BAICHUAN = {
|
||||
"type": "baichuan",
|
||||
"name": "Baichuan",
|
||||
@ -308,6 +321,12 @@ class ModelType(enum.Enum):
|
||||
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
|
||||
"multimodal": True,
|
||||
}
|
||||
MLLAMA = {
|
||||
"type": "mllama",
|
||||
"name": "Mllama",
|
||||
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"multimodal": True,
|
||||
}
|
||||
|
||||
|
||||
__GLOBALS = locals()
|
||||
@ -323,6 +342,7 @@ def get_model(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
@ -334,6 +354,7 @@ def get_model(
|
||||
model_type = config_dict.get("model_type", None)
|
||||
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
compression_config = config_dict.get("compression_config", None)
|
||||
if quantization_config is not None and quantize is None:
|
||||
method = quantization_config.get("quant_method", None)
|
||||
if method in {"gptq", "awq", "exl2"}:
|
||||
@ -344,6 +365,23 @@ def get_model(
|
||||
quantize = "fp8"
|
||||
else:
|
||||
log_master(logger.warning, f"Unknown quantization method {method}")
|
||||
elif compression_config is not None:
|
||||
# TODO: at some point we should probably fully parse the compression
|
||||
# configuration to know which parameters are compressed.
|
||||
config_groups = compression_config.get("config_groups")
|
||||
if config_groups is not None:
|
||||
for _, group in config_groups.items():
|
||||
weights_config = group.get("weights")
|
||||
if weights_config is not None:
|
||||
if (
|
||||
weights_config["type"] == "float"
|
||||
and weights_config["num_bits"] == 8
|
||||
):
|
||||
log_master(
|
||||
logger.info, "Auto selecting quantization method fp8"
|
||||
)
|
||||
quantize = "fp8"
|
||||
break
|
||||
|
||||
if dtype is None:
|
||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
@ -366,6 +404,13 @@ def get_model(
|
||||
else:
|
||||
raise RuntimeError(f"Unknown dtype {dtype}")
|
||||
|
||||
if kv_cache_dtype is None:
|
||||
kv_cache_dtype = dtype
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
kv_cache_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
|
||||
|
||||
if speculate is not None:
|
||||
set_speculate(speculate)
|
||||
else:
|
||||
@ -526,6 +571,7 @@ def get_model(
|
||||
speculator=speculator,
|
||||
default_dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DeepseekV2Config,
|
||||
@ -580,6 +626,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
@ -631,6 +678,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -666,6 +714,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -704,6 +753,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=GPTNeoXConfig,
|
||||
@ -737,6 +787,31 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == PHI_MOE:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashLlamaForCausalLM,
|
||||
config_class=PhiMoEConfig,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -768,7 +843,6 @@ def get_model(
|
||||
)
|
||||
|
||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
||||
print(f">>> model_type: {model_type}")
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
@ -777,6 +851,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -800,6 +875,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -825,6 +901,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -851,6 +928,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -875,6 +953,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Dbrx works better in bfloat16.
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -905,6 +984,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
@ -923,6 +1003,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
@ -950,6 +1031,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -974,6 +1056,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -998,6 +1081,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -1024,6 +1108,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
@ -1068,7 +1153,7 @@ def get_model(
|
||||
)
|
||||
if model_type == IDEFICS:
|
||||
if FLASH_ATTENTION:
|
||||
return IDEFICSSharded(
|
||||
return IdeficsCausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -1078,6 +1163,22 @@ def get_model(
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == MLLAMA:
|
||||
if FLASH_ATTENTION:
|
||||
return MllamaCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=MllamaForConditionalGeneration,
|
||||
batch_class=MllamaCausalLMBatch,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
|
||||
if model_type == IDEFICS2:
|
||||
if FLASH_ATTENTION:
|
||||
return VlmCausalLM(
|
||||
@ -1087,6 +1188,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
@ -1104,6 +1206,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -1122,6 +1225,7 @@ def get_model(
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
@ -1194,6 +1298,7 @@ def get_model_with_lora_adapters(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
adapter_to_index: Dict[str, int],
|
||||
@ -1207,6 +1312,7 @@ def get_model_with_lora_adapters(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
@ -39,6 +38,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -290,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=key, value=value, slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else key,
|
||||
kv_cache[1] if SYSTEM != "ipex" else value,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
@ -307,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
@ -28,8 +28,8 @@ if SYSTEM != "ipex":
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
@ -329,15 +329,15 @@ class DbrxAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
@ -346,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user