Merge branch 'main' into gpt_awq_4

This commit is contained in:
Wang, Yi A 2024-10-08 07:55:38 -04:00
commit 92fa7ac7e9
139 changed files with 14200 additions and 4450 deletions

View File

@ -4,3 +4,4 @@ server/transformers
server/flash-attention
cmake-build-debug/
cmake-build-release/
Dockerfile*

View File

@ -45,7 +45,7 @@ jobs:
export dockerfile="Dockerfile"
export label_extension=""
export docker_devices=""
export runs_on="aws-g6-12xlarge-plus-priv"
export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
;;
rocm)

View File

@ -42,6 +42,7 @@ jobs:
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install
run: |
sudo apt update
sudo apt install python3.11-dev -y
make install-cpu
- name: Run server tests
@ -57,3 +58,6 @@ jobs:
- name: Run Rust tests
run: |
cargo test
- name: Run Rust tests with google feature
run: |
cargo test --features google

3
.gitignore vendored
View File

@ -3,9 +3,8 @@ target
router/tokenizer.json
*__pycache__*
backends/v2/src/client/pb
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files
*.hip

View File

@ -23,9 +23,11 @@ docs/openapi.json:
- '#/components/schemas/GenerateResponse/properties/details/nullable'
- '#/components/schemas/StreamResponse/properties/details/nullable'
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
- '#/components/schemas/ToolChoice/nullable'
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
no-invalid-media-type-examples:
- '#/paths/~1/post/responses/422/content/application~1json/example'

677
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +1,26 @@
[workspace]
members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
"backends/trtllm",
"backends/client",
"launcher"
"launcher",
"router"
]
default-members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"backends/client",
"launcher"
"launcher",
"router"
]
resolver = "2"
[workspace.package]
version = "2.2.1-dev0"
version = "2.3.2-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
@ -31,6 +33,7 @@ metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[profile.release]
incremental = true

View File

@ -40,7 +40,6 @@ COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@ -258,7 +257,7 @@ COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2

24
Dockerfile.nix Normal file
View File

@ -0,0 +1,24 @@
# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix
# docker run --log-driver=none tgi-nix-builder | docker load
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure
FROM ubuntu:24.04
WORKDIR /app
# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]

View File

@ -41,7 +41,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
FROM rocm/dev-ubuntu-22.04:6.2 AS base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@ -50,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-dev && \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@ -100,41 +101,132 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
pip install .
RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}
# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
&& cd build/release \
&& make package
FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
&& cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
&& cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
# # AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
FROM base as build_pytorch
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
USE_ROCM="1" \
BUILD_TEST="0" \
USE_FBGEMM="0" \
USE_NNPACK="0" \
USE_QNNPACK="0" \
USE_XNNPACK="0" \
USE_FLASH_ATTENTION="1" \
USE_MEM_EFF_ATTENTION="0"
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
RUN git clone ${PYTORCH_REPO} pytorch \
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
&& pip install -r requirements.txt --no-cache-dir \
&& python tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch as export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tunning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
FROM base AS install_deps
FROM base AS kernel-builder
ARG COMMON_WORKDIR
# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
# RCCL needs to be installed twice
&& dpkg -i /install/*.deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y triton \
&& pip install /install/*.whl; \
fi
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y amdsmi \
&& pip install /install/*.whl;
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi
FROM install_deps AS kernel-builder
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
@ -174,7 +266,7 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build
FROM base AS base-copy
FROM install_deps AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -224,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tunning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

View File

@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model
ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
```
And then you can make requests like

View File

@ -1,17 +0,0 @@
{
mkPoetryApplication,
pkg-config,
protobuf,
openssl,
}:
mkPoetryApplication {
# name = "text-generation-server";
projectDir = ./server;
# nativeBuildInputs = [ pkg-config ];
# buildInputs = [ openssl.dev protobuf ];
}

75
backends/v2/Cargo.toml Normal file
View File

@ -0,0 +1,75 @@
[package]
name = "text-generation-router-v2"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router-v2"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]

19
backends/v2/build.rs Normal file
View File

@ -0,0 +1,19 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/client/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/client/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

506
backends/v2/src/backend.rs Normal file
View File

@ -0,0 +1,506 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub struct BackendV2 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
impl BackendV2 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
) -> Self {
// Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client.clone(),
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
));
Self {
queue,
batching_task_notifier,
client,
}
}
}
#[async_trait]
impl Backend for BackendV2 {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the queue
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok(UnboundedReceiverStream::new(response_rx))
}
async fn health(&self, current_health: bool) -> bool {
if current_health {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok()
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self {
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v2_finish_reason {
crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
}
}

View File

@ -0,0 +1,257 @@
/// Single shard Client
use crate::client::pb;
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v2 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -0,0 +1,68 @@
//! Text Generation gRPC client library
use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod grpc_client;
mod sharded_client;
pub use grpc_client::Client;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
#[error("Server error: {0}")]
Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -0,0 +1,252 @@
/// Multi shard Client
use crate::client::{ClientError, Result};
use crate::client::{Health, ShardInfo};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::InfoResponse;
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
self.clone().prefill(batch).await?;
Ok(())
}
}

141
backends/v2/src/lib.rs Normal file
View File

@ -0,0 +1,141 @@
mod backend;
mod client;
mod queue;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV2;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
#[schema(example = "1")]
pub speculate: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(BackendV2, BackendInfo), V2Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V2Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V2Error::Connection)?;
// server is running on v2
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V2Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V2Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
};
let backend = BackendV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
);
tracing::info!("Using backend V3");
Ok((backend, backend_info))
}
#[derive(Debug, Error)]
pub enum V2Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}

212
backends/v2/src/main.rs Normal file
View File

@ -0,0 +1,212 @@
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use text_generation_router_v2::{connect_backend, V2Error};
use thiserror::Error;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
command,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
api_key,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
} = args;
if let Some(Commands::PrintSchema) = command {
use utoipa::OpenApi;
let api_doc = text_generation_router::server::ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
};
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if let Some(max_batch_size) = max_batch_size {
if max_batch_size == 0 {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` must be > 0".to_string(),
));
}
}
let (backend, _backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
)
.await?;
// Run server
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
)
.await?;
Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Backend failed: {0}")]
Backend(#[from] V2Error),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

View File

@ -1,14 +1,14 @@
use crate::infer::{InferError, InferStreamResponse};
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -218,7 +218,7 @@ impl State {
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
next_batch_span.follows_from(Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
@ -404,6 +404,7 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tracing::info_span;
fn default_entry() -> (
@ -415,7 +416,9 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0,
add_special_tokens: true,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {

View File

@ -100,6 +100,7 @@ pub async fn connect_backend(
.map_err(V3Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
let backend_info = BackendInfo {
waiting_served_ratio,

View File

@ -364,7 +364,7 @@ impl State {
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break;
continue;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
@ -436,6 +436,12 @@ impl State {
batch_entries.insert(id, entry);
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);

View File

@ -1,10 +1,22 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::hash::{Hash, Hasher};
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
fn hash(slice: &[u32]) -> u64 {
assert!(!slice.is_empty());
if slice.len() == 1 {
slice[0] as u64
} else {
let mut s = std::hash::DefaultHasher::new();
slice.hash(&mut s);
s.finish()
}
}
pub struct RadixAllocator {
allocation_id: u64,
@ -44,6 +56,10 @@ impl RadixAllocator {
// the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet.
tracing::debug!(
"Free blocks {} need {n_blocks_needed}",
self.free_blocks.len()
);
self.free_blocks.extend(
self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()),
@ -94,6 +110,9 @@ impl Allocator for RadixAllocator {
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => {
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
tracing::debug!("Block size {}", self.block_size);
self.cache_blocks
.decref(prefix_node)
.expect("Failed to decrement refcount");
@ -211,7 +230,6 @@ struct RadixAllocation {
pub enum TrieError {
InvalidNodeId,
RefCountUnderflow,
BlockTokenCountMismatch,
}
pub type NodeId = DefaultKey;
@ -268,7 +286,9 @@ impl RadixTrie {
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if let Some(&child_id) = node.children.get(&key[0]) {
if key.len() >= self.block_size {
let node_key = hash(&key[..self.block_size]);
if let Some(&child_id) = node.children.get(&node_key) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
@ -280,6 +300,7 @@ impl RadixTrie {
node_id = self.find_(child_id, key, blocks);
}
}
}
node_id
}
@ -344,9 +365,11 @@ impl RadixTrie {
// evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case.
let mut evicted = Vec::new();
tracing::debug!("Evicting in search of {n_blocks}");
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks - evicted.len();
let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leave does not exist");
assert_eq!(
@ -368,8 +391,11 @@ impl RadixTrie {
// the required number of blocks and leave the remaining blocks
// untouched.
let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
node.key.truncate(node.blocks.len() - blocks_needed);
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
let truncate_blocks = node.blocks.len() - blocks_needed;
let truncate_tokens = truncate_blocks * self.block_size;
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
break;
}
@ -400,11 +426,10 @@ impl RadixTrie {
// the part of the prefix that is already in the trie to detect
// mismatches.
if tokens.len() != blocks.len() * self.block_size {
return Err(TrieError::BlockTokenCountMismatch);
}
assert_eq!(tokens.len(), blocks.len() * self.block_size);
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
let node_key = hash(&tokens[..self.block_size]);
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
self.update_access_time(child_id);
let child = self
.nodes
@ -452,14 +477,15 @@ impl RadixTrie {
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len);
let prefix_blocks = prefix_len / self.block_size;
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = node.key[0];
let node_key = hash(&node.key[..self.block_size]);
let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
@ -484,7 +510,7 @@ impl RadixTrie {
) -> NodeId {
let key = key.into();
let blocks = blocks.into();
let first = key[0];
let first = hash(&key[..self.block_size]);
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child);
@ -496,10 +522,10 @@ impl RadixTrie {
}
/// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(first, child_id).is_none() {
if parent.children.insert(hash, child_id).is_none() {
// Only increase reference count if child does not replace another child.
self.incref(parent_id)
.expect("Failed to increase parent refcount");
@ -517,7 +543,9 @@ impl RadixTrie {
);
let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
parent.children.remove(&node.key[0]);
let node_key = hash(&node.key[..self.block_size]);
parent.children.remove(&node_key);
self.decref(parent_id)
.expect("Failed to decrease parent refcount");
node
@ -571,7 +599,7 @@ impl RadixTrie {
#[derive(Debug)]
struct TrieNode {
blocks: Vec<u32>,
children: HashMap<u32, NodeId>,
children: HashMap<u64, NodeId>,
key: Vec<u32>,
last_accessed: u64,
parent: Option<NodeId>,

View File

@ -16,7 +16,6 @@ path = "src/main.rs"
[dependencies]
average = "0.14"
clap = { version = "4.4.5", features = ["derive", "env"] }
crossterm = "0.27"
float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0"
@ -25,7 +24,7 @@ text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
ratatui = "0.28.1"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = { workspace = true }

View File

@ -7,7 +7,7 @@
</div>
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
and powered by [Ratatui](https://github.com/ratatui/ratatui).
## Install

View File

@ -1,16 +1,15 @@
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
use crate::generation::{Decode, Message, Prefill};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
use tui::backend::Backend;
use tui::layout::{Alignment, Constraint, Direction, Layout};
use tui::style::{Color, Modifier, Style};
use tui::text::{Line, Span};
use tui::widgets::{
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
use ratatui::style::{Color, Modifier, Style};
use ratatui::text::{Line, Span};
use ratatui::widgets::{
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
};
use tui::{symbols, Frame};
use ratatui::{symbols, Frame};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
/// TUI powered App
pub(crate) struct App {
@ -153,7 +152,7 @@ impl App {
}
/// Render frame
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
pub fn render(&mut self, f: &mut Frame) {
let batch_progress =
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
let run_progress =
@ -172,7 +171,7 @@ impl App {
]
.as_ref(),
)
.split(f.size());
.split(f.area());
// Top row horizontal layout
let top = Layout::default()
@ -239,7 +238,7 @@ impl App {
f.render_widget(helper, row5[0]);
// Batch tabs
let titles = self
let titles: Vec<Line> = self
.data
.batch_size
.iter()

View File

@ -1,5 +1,5 @@
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
use crossterm::event;
use ratatui::crossterm::event;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc};

View File

@ -6,13 +6,13 @@ mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use ratatui::backend::CrosstermBackend;
use ratatui::crossterm::ExecutableCommand;
use ratatui::Terminal;
use std::io;
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
use tui::Terminal;
/// Run benchmarking app
#[allow(clippy::too_many_arguments)]
@ -50,9 +50,9 @@ pub async fn run(
};
// Initialize terminal properties
crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(crossterm::cursor::Hide)?;
ratatui::crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
// Initialize terminal
let mut terminal = {
@ -128,9 +128,9 @@ pub async fn run(
let _ = shutdown_guard_receiver.recv().await;
// Revert terminal to original view
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
ratatui::crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
let parameters_table = table::parameters_table(
tokenizer_name,

View File

@ -28,11 +28,17 @@ class ToolCall(BaseModel):
function: dict
class Chunk(BaseModel):
type: str
text: Optional[str] = None
image_url: Any = None
class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str] = None
content: Optional[Union[str, List[Chunk]]] = None
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
@ -168,7 +174,7 @@ class ChatCompletionComplete(BaseModel):
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
finish_reason: Optional[str]
# Usage details of the chat completion
usage: Optional[Any] = None
@ -191,6 +197,7 @@ class ChatCompletionChunk(BaseModel):
model: str
system_fingerprint: str
choices: List[Choice]
usage: Optional[Any] = None
class Parameters(BaseModel):

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.2.1-dev0"
"version": "2.3.2-dev0"
},
"paths": {
"/": {
@ -742,6 +742,14 @@
},
"system_fingerprint": {
"type": "string"
},
"usage": {
"allOf": [
{
"$ref": "#/components/schemas/Usage"
}
],
"nullable": true
}
}
},
@ -937,6 +945,14 @@
"stream": {
"type": "boolean"
},
"stream_options": {
"allOf": [
{
"$ref": "#/components/schemas/StreamOptions"
}
],
"nullable": true
},
"temperature": {
"type": "number",
"format": "float",
@ -1912,6 +1928,19 @@
}
}
},
"StreamOptions": {
"type": "object",
"required": [
"include_usage"
],
"properties": {
"include_usage": {
"type": "boolean",
"description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.",
"example": "true"
}
}
},
"StreamResponse": {
"type": "object",
"required": [

View File

@ -10,7 +10,7 @@ This diagram shows well there are these separate components:
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
- **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
The router and the model server can be two different machines, they do not need to be deployed together.

View File

@ -36,7 +36,13 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
To specify model revision, use `adapter_id@revision`, as follows:
```bash
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
```
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
```bash
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
@ -72,6 +78,22 @@ curl 127.0.0.1:3000/generate \
}'
```
If you are using a lora adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example payload:
```json
curl 127.0.0.1:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "Hello who are you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "myadapter"
}
}'
```
> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon.
An updated tutorial with detailed examples will be published soon. Stay tuned!

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-rocm \
ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
--model-id $model
```
@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
## Custom PagedAttention
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
## Unsupported features
The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-xpu \
--model-id $model --cuda-graphs 0
```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-cpu \
--model-id $model --cuda-graphs 0
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 \
ghcr.io/huggingface/text-generation-inference:2.3.1 \
--model-id $model
```

View File

@ -11,10 +11,19 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 \
ghcr.io/huggingface/text-generation-inference:2.3.1 \
--model-id $model
```
<Tip>
If you want to serve gated or private models, which provide
controlled access to sensitive or proprietary content, refer to
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
for detailed instructions.
</Tip>
### Supported hardware
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.

View File

@ -55,7 +55,9 @@ Options:
## QUANTIZE
```shell
--quantize <QUANTIZE>
Whether you want the model to be quantized
Quantization method to use for the model. It is not necessary to specify this option for pre-quantized models, since the quantization method is read from the model configuration.
Marlin kernels will be used automatically for GPTQ/AWQ models.
[env: QUANTIZE=]
@ -87,6 +89,15 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]
```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]
```
## TRUST_REMOTE_CODE
```shell

View File

@ -20,6 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5)
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
@ -34,6 +35,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:

View File

@ -479,11 +479,11 @@
"systems": "systems_6"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1726560853,
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
"type": "github"
},
"original": {
@ -497,11 +497,11 @@
"systems": "systems_7"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1726560853,
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
"type": "github"
},
"original": {
@ -718,11 +718,11 @@
},
"nixpkgs_6": {
"locked": {
"lastModified": 1724915739,
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
"lastModified": 1727675176,
"narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
"rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
"type": "github"
},
"original": {
@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1726021481,
"narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
"lastModified": 1727836133,
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
"type": "github"
},
"original": {
@ -978,16 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1725950569,
"narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
"lastModified": 1728029332,
"narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "98049f853346ca780b81fee730715c90d33ac2b4",
"type": "github"
},
"original": {
"owner": "danieldk",
"repo": "tgi-nix",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"type": "github"
}
}

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:danieldk/tgi-nix";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
@ -37,6 +37,7 @@
overlays = [
rust-overlay.overlays.default
tgi-nix.overlays.default
(import nix/overlay.nix)
];
};
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
@ -67,8 +68,37 @@
'';
};
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
in
{
checks = {
rust =
with pkgs;
rustPlatform.buildRustPackage {
name = "rust-checks";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
};
buildInputs = [ openssl.dev ];
nativeBuildInputs = [
clippy
pkg-config
protobuf
python3
rustfmt
];
buildPhase = ''
cargo check
'';
checkPhase = ''
cargo fmt -- --check
cargo test -j $NIX_BUILD_CORES
cargo clippy
'';
installPhase = "touch $out";
};
};
formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec {
default = pure;
@ -84,10 +114,11 @@
test = mkShell {
buildInputs =
[
# benchmark
# launcher
# router
benchmark
launcher
router
server
client
openssl.dev
pkg-config
cargo
@ -102,52 +133,17 @@
pre-commit
ruff
]);
};
impure = mkShell {
buildInputs =
[
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
"rust-analyzer"
"rust-src"
];
})
protobuf
]
++ (with python3.pkgs; [
venvShellHook
docker
pip
ipdb
click
pyright
pytest
pytest-asyncio
ruff
syrupy
]);
impure = callPackage ./nix/impure-shell.nix { inherit server; };
inputsFrom = [ server ];
venvDir = "./.venv";
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin
'';
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
};
};
packages.default = pkgs.writeShellApplication {
packages = rec {
default = pkgs.writeShellApplication {
name = "text-generation-inference";
runtimeInputs = [
server
@ -157,6 +153,16 @@
${launcher}/bin/text-generation-launcher "$@"
'';
};
dockerImage = pkgs.callPackage nix/docker.nix {
text-generation-inference = default;
};
dockerImageStreamed = pkgs.callPackage nix/docker.nix {
text-generation-inference = default;
stream = true;
};
};
}
);
}

View File

@ -336,12 +336,14 @@ def launcher(event_loop):
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -374,6 +376,9 @@ def launcher(event_loop):
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
@ -401,6 +406,8 @@ def launcher(event_loop):
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
@ -431,12 +438,14 @@ def launcher(event_loop):
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
):
port = random.randint(8000, 10_000)
@ -452,6 +461,9 @@ def launcher(event_loop):
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
@ -491,6 +503,8 @@ def launcher(event_loop):
}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
if HF_TOKEN is not None:
env["HF_TOKEN"] = HF_TOKEN
@ -522,6 +536,7 @@ def launcher(event_loop):
devices=devices,
volumes=volumes,
ports={"80/tcp": port},
healthcheck={"timeout": int(10 * 1e9)},
shm_size="1G",
)
@ -582,7 +597,6 @@ def generate_multi():
max_new_tokens: int,
seed: Optional[int] = None,
) -> List[Response]:
import numpy as np
arange = np.arange(len(prompts))

View File

@ -0,0 +1,206 @@
[
{
"choices": [
{
"delta": {
"content": "**",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "Deep",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Learning",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": ":",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " An",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Overview",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "**\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "================================",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "=====",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "\n\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 40,
"total_tokens": 50
}
}
]

View File

@ -24,13 +24,13 @@
"tokens": [
{
"id": 1736,
"logprob": -2.03125,
"logprob": -2.109375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.8671875,
"logprob": -1.90625,
"special": false,
"text": "\n\n"
},
@ -42,48 +42,48 @@
},
{
"id": 2121,
"logprob": -1.8125,
"logprob": -1.796875,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.24121094,
"logprob": -0.24511719,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.100097656,
"logprob": -0.09326172,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.9453125,
"logprob": -0.95703125,
"special": false,
"text": " is"
},
{
"id": 476,
"logprob": -1.703125,
"id": 1671,
"logprob": -1.5859375,
"special": false,
"text": " a"
"text": " used"
},
{
"id": 4551,
"logprob": -2.453125,
"id": 577,
"logprob": -0.39257812,
"special": false,
"text": " document"
"text": " to"
},
{
"id": 674,
"logprob": -0.796875,
"id": 3853,
"logprob": -1.25,
"special": false,
"text": " that"
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is a document that"
"generated_text": " form\n\nThe test request form is used to request"
}

View File

@ -11,12 +11,12 @@
},
{
"id": 2015,
"logprob": -9.640625,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.375,
"logprob": -10.3671875,
"text": " request"
}
],
@ -24,19 +24,19 @@
"tokens": [
{
"id": 604,
"logprob": -0.2824707,
"logprob": -0.28271484,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -0.19030762,
"logprob": -0.18493652,
"special": false,
"text": " the"
},
{
"id": 16819,
"logprob": -1.4892578,
"logprob": -1.4804688,
"special": false,
"text": " detection"
},
@ -46,44 +46,44 @@
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": -2.0195312,
"special": false,
"text": " the"
},
{
"id": 8566,
"logprob": 0.0,
"special": false,
"text": " presence"
},
{
"id": 689,
"logprob": -0.16491699,
"special": false,
"text": " or"
},
{
"id": 14862,
"logprob": 0.0,
"special": false,
"text": " absence"
},
{
"id": 576,
"logprob": -0.9946289,
"special": false,
"text": " of"
},
{
"id": 671,
"logprob": -0.5263672,
"logprob": -2.1738281,
"special": false,
"text": " an"
},
{
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " RNA"
},
{
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " virus"
},
{
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " in"
},
{
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " patients"
},
{
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " who"
}
],
"top_tokens": null
},
"generated_text": "Test request for the detection of the presence or absence of an"
"generated_text": "Test request for the detection of an RNA virus in patients who"
}

View File

@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}

View File

@ -0,0 +1,57 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 374,
"logprob": -22.96875,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"special": false,
"text": " "
},
{
"id": 128001,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
}

View File

@ -0,0 +1,418 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
]

View File

@ -0,0 +1,114 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1445312,
"text": "What"
},
{
"id": 349,
"logprob": -1.4648438,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.6005859,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39526367,
"text": "?"
},
{
"id": 13,
"logprob": -0.640625,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18774414,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.96484375,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.003168106,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16540527,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.08886719,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.75878906,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5703125,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11242676,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.7939453,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17102051,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34326172,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 24871,
"logprob": -17.234375,
"text": "descent"
},
{
"id": 28804,
"logprob": -7.4375,
"text": "?"
},
{
"id": 13,
"logprob": -0.8046875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.33032227,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 1313,
"logprob": -2.3613281,
"special": false,
"text": "It"
},
{
"id": 3969,
"logprob": -0.7285156,
"special": false,
"text": " seems"
},
{
"id": 298,
"logprob": -1.3466797,
"special": false,
"text": " to"
},
{
"id": 528,
"logprob": 0.0,
"special": false,
"text": " me"
},
{
"id": 28725,
"logprob": -1.6757812,
"special": false,
"text": ","
},
{
"id": 369,
"logprob": 0.0,
"special": false,
"text": " that"
},
{
"id": 513,
"logprob": -1.1269531,
"special": false,
"text": " if"
},
{
"id": 368,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 28742,
"logprob": -2.4921875,
"special": false,
"text": "'"
},
{
"id": 267,
"logprob": 0.0,
"special": false,
"text": "re"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nIt seems to me, that if you're"
}

View File

@ -0,0 +1,458 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1445312,
"text": "What"
},
{
"id": 349,
"logprob": -1.4648438,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.6005859,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39526367,
"text": "?"
},
{
"id": 13,
"logprob": -0.640625,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18774414,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.96484375,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.003168106,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16369629,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.0881958,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.76708984,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.57373047,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11291504,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.79589844,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.1694336,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34350586,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1445312,
"text": "What"
},
{
"id": 349,
"logprob": -1.4677734,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.6015625,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39453125,
"text": "?"
},
{
"id": 13,
"logprob": -0.6435547,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18713379,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.9628906,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.0032176971,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16540527,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.08898926,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.765625,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5708008,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11401367,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.7963867,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17028809,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34326172,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.140625,
"text": "What"
},
{
"id": 349,
"logprob": -1.4658203,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6796875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.5898438,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.3955078,
"text": "?"
},
{
"id": 13,
"logprob": -0.64501953,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18493652,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.9580078,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.0032176971,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16552734,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.08874512,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.75878906,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5703125,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11236572,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.79541016,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17102051,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34326172,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1328125,
"text": "What"
},
{
"id": 349,
"logprob": -1.4658203,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6796875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.5947266,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39648438,
"text": "?"
},
{
"id": 13,
"logprob": -0.6464844,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18688965,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.9609375,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.003168106,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16601562,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.088134766,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.7597656,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5708008,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11291504,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.7944336,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17102051,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34399414,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
}
]

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8808594,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37280273,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.26098633,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017137527,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2695312,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9238281,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48828125,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -0.34838867,
"special": false,
"text": "\n"
},
{
"id": 13940,
"logprob": -0.38916016,
"special": false,
"text": "``"
},
{
"id": 28832,
"logprob": 0.0,
"special": false,
"text": "`"
},
{
"id": 3371,
"logprob": -1.2529297,
"special": false,
"text": "json"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 28751,
"logprob": 0.0,
"special": false,
"text": "{"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 2287,
"logprob": 0.0,
"special": false,
"text": " "
},
{
"id": 345,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 3134,
"logprob": -0.640625,
"special": false,
"text": "request"
}
],
"top_tokens": null
},
"generated_text": "Test request\n```json\n{\n \"request"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
}
]

View File

@ -0,0 +1,109 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"logprob": null,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"logprob": 0.0,
"special": false,
"text": "!"
},
{
"id": 739,
"logprob": 0.0,
"special": false,
"text": " It"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 29915,
"logprob": 0.0,
"special": false,
"text": "'"
},
{
"id": 276,
"logprob": -0.9838867,
"special": false,
"text": "re"
},
{
"id": 3211,
"logprob": 0.0,
"special": false,
"text": " address"
},
{
"id": 292,
"logprob": 0.0,
"special": false,
"text": "ing"
},
{
"id": 263,
"logprob": -0.15124512,
"special": false,
"text": " a"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
}

View File

@ -0,0 +1,438 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7128906,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.71484375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7192383,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}
]

View File

@ -0,0 +1,106 @@
[
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}
]

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727556016,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}

View File

@ -3,9 +3,7 @@ import requests
import json
from aiohttp import ClientSession
from text_generation.types import (
Completion,
)
from text_generation.types import Completion, ChatCompletionChunk
@pytest.fixture(scope="module")
@ -50,6 +48,114 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot
@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
flash_llama_completion, response_snapshot
):
url = f"{flash_llama_completion.base_url}/v1/chat/completions"
request = {
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is Deep Learning?",
}
],
"max_tokens": 10,
"temperature": 0.0,
"stream_options": {"include_usage": True},
"stream": True,
}
string = ""
chunks = []
had_usage = False
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(ChatCompletionChunk(**c))
assert "choices" in c
if len(c["choices"]) == 1:
index = c["choices"][0]["index"]
assert index == 0
string += c["choices"][0]["delta"]["content"]
has_usage = c["usage"] is not None
assert not had_usage
if has_usage:
had_usage = True
else:
raise RuntimeError("Expected different payload")
assert had_usage
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
assert chunks == response_snapshot
request = {
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is Deep Learning?",
}
],
"max_tokens": 10,
"temperature": 0.0,
"stream": True,
}
string = ""
chunks = []
had_usage = False
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(ChatCompletionChunk(**c))
assert "choices" in c
if len(c["choices"]) == 1:
index = c["choices"][0]["index"]
assert index == 0
string += c["choices"][0]["delta"]["content"]
has_usage = c["usage"] is not None
assert not had_usage
if has_usage:
had_usage = True
else:
raise RuntimeError("Expected different payload")
assert not had_usage
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(

View File

@ -0,0 +1,77 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
await flash_llama_fp8_kv_cache_handle.health(300)
return flash_llama_fp8_kv_cache_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
response = await flash_llama_fp8_kv_cache.generate(
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert (
response.generated_text
== " Deep learning is a subset of machine learning that is"
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
flash_llama_fp8_kv_cache, response_snapshot
):
response = await flash_llama_fp8_kv_cache.generate(
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert (
responses[0].generated_text
== " Deep learning is a subset of machine learning that is"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,75 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_handle(launcher):
with launcher("mistralai/Mixtral-8x7B-v0.1", num_shard=8) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mixtral(flash_mixtral_handle):
await flash_mixtral_handle.health(300)
return flash_mixtral_handle.client
@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral(flash_mixtral, response_snapshot):
response = await flash_mixtral.generate(
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is an optimization algorithm used to minimize"
)
assert response == response_snapshot
@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral_all_params(flash_mixtral, response_snapshot):
response = await flash_mixtral.generate(
"What is gradient descent?\n\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nIt seems to me, that if you're"
)
assert response == response_snapshot
@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral_load(flash_mixtral, generate_load, response_snapshot):
responses = await generate_load(
flash_mixtral, "What is gradient descent?\n\n", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is an optimization algorithm used to minimize"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -0,0 +1,60 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
await flash_mixtral_gptq_handle.health(300)
return flash_mixtral_gptq_handle.client
@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_load(
flash_mixtral_gptq, generate_load, response_snapshot
):
responses = await generate_load(
flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -0,0 +1,75 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
with launcher(
"microsoft/Phi-3.5-MoE-instruct",
num_shard=4,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
await flash_phi35_moe_handle.health(300)
return flash_phi35_moe_handle.client
@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
responses = await generate_load(
flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -0,0 +1,105 @@
import pytest
import base64
import asyncio
@pytest.fixture(scope="module")
def mllama_handle(launcher):
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mllama(mllama_handle):
await mllama_handle.health(300)
return mllama_handle.client
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
assert response.usage == {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60,
}
assert (
response.choices[0].message.content
== "In a bustling city, a chicken named Cluck"
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
futures = [
mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
for i in range(4)
]
responses = await asyncio.gather(*futures)
generated_texts = [response.choices[0].message.content for response in responses]
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
"meta-llama/Meta-Llama-3.1-8B-Instruct",
num_shard=2,
disable_grammar_support=False,
) as handle:
yield handle
@ -208,7 +210,7 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1
assert count == 48
assert count == 28
assert response == response_snapshot

View File

@ -12,11 +12,13 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
pyo3 = { workspace = true }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"
[dev-dependencies]
float_eq = "1.0.1"

21
launcher/src/gpu.rs Normal file
View File

@ -0,0 +1,21 @@
pub fn get_cuda_capability() -> Option<(usize, usize)> {
use pyo3::prelude::*;
let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
let torch = py.import_bound("torch.cuda")?;
let get_device_capability = torch.getattr("get_device_capability")?;
get_device_capability.call0()?.extract()
};
match pyo3::Python::with_gil(py_get_capability) {
Ok((major, minor)) if major < 0 || minor < 0 => {
tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
None
}
Ok((major, minor)) => Some((major as usize, minor as usize)),
Err(err) => {
tracing::warn!("Cannot determine GPU compute capability: {}", err);
None
}
}
}

View File

@ -5,6 +5,7 @@ use hf_hub::{
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -26,6 +27,7 @@ use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
mod gpu;
fn get_config(
model_id: &str,
@ -65,6 +67,7 @@ fn get_config(
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
@ -77,6 +80,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
prefix_caching = Some("0".to_string());
}
}
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
"paged"
} else {
"flashdecoding"
};
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
@ -89,10 +99,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
"Forcing attention to '{fallback_attention}' because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some("flashdecoding".to_string());
attention = Some(fallback_attention.to_string());
}
if fallback_attention == "paged" && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
prefix_caching = Some("0".to_string());
}
}
Some("t5") => {}
@ -101,8 +115,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some(fallback_attention.to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
@ -110,8 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
(prefix_caching, attention)
}
@ -285,6 +301,22 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
impl std::fmt::Display for KVCacheDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KVCacheDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
}
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
@ -367,7 +399,11 @@ struct Args {
#[clap(long, env)]
num_shard: Option<usize>,
/// Whether you want the model to be quantized.
/// Quantization method to use for the model. It is not necessary to specify this option
/// for pre-quantized models, since the quantization method is read from the model
/// configuration.
///
/// Marlin kernels will be used automatically for GPTQ/AWQ models.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
@ -382,6 +418,12 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Specify the dtype for the key-value cache. When this option is not provided,
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
/// the only supported value is `fp8_e5m2` on CUDA.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KVCacheDtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@ -650,6 +692,7 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
kv_cache_dtype: Option<KVCacheDtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@ -723,6 +766,11 @@ fn shard_manager(
shard_args.push(dtype.to_string())
}
if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_args.push("--revision".to_string());
@ -1034,6 +1082,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
Ok(log) => log.trace(),
// For interactive debugging ?
Err(_) => {
if LevelFilter::current() >= tracing::Level::DEBUG {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
@ -1045,6 +1094,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
}
}
}
}
}
fn find_num_shards(
@ -1277,6 +1327,7 @@ fn spawn_shards(
let otlp_service_name = args.otlp_service_name.clone();
let speculate = args.speculate;
let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@ -1295,6 +1346,7 @@ fn spawn_shards(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
rank,
@ -1787,14 +1839,37 @@ fn main() -> Result<(), LauncherError> {
if adapter.contains('=') {
continue;
}
let adapter = adapter.trim();
// check if adapter has more than 1 '@'
if adapter.matches('@').count() > 1 {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
// capture adapter_id, path, revision in format of adapter_id=path@revision
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
if let Some(caps) = re.captures(adapter) {
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
let revision = caps.get(3).map(|m| m.as_str());
download_convert_model(
adapter,
None,
adapter_id,
revision,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
} else {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
}
}

21
nix/client.nix Normal file
View File

@ -0,0 +1,21 @@
{
buildPythonPackage,
poetry-core,
huggingface-hub,
pydantic,
}:
buildPythonPackage {
name = "text-generation";
src = ../clients/python;
pyproject = true;
build-system = [ poetry-core ];
dependencies = [
huggingface-hub
pydantic
];
}

23
nix/docker.nix Normal file
View File

@ -0,0 +1,23 @@
{
dockerTools,
cacert,
text-generation-inference,
stream ? false,
}:
let
build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;
in
build {
name = "tgi-docker";
tag = "latest";
config = {
EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
Env = [
"HF_HOME=/data"
"PORT=80"
];
};
contents = [ cacert ];
}

54
nix/impure-shell.nix Normal file
View File

@ -0,0 +1,54 @@
{
mkShell,
openssl,
pkg-config,
protobuf,
python3,
pyright,
redocly,
ruff,
rust-bin,
server,
}:
mkShell {
buildInputs =
[
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
"rust-analyzer"
"rust-src"
];
})
protobuf
pyright
redocly
ruff
]
++ (with python3.pkgs; [
venvShellHook
docker
pip
ipdb
click
pytest
pytest-asyncio
syrupy
]);
inputsFrom = [ server ];
venvDir = "./.venv";
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin
'';
}

41
nix/overlay.nix Normal file
View File

@ -0,0 +1,41 @@
final: prev: {
# You can use this overlay to temporarily override packages for
# development. For permanent overrides, it's better to do this in
# our package flake:
#
# https://github.com/huggingface/text-generation-inference-nix
#
# Note that overriding packages that are in the transitive closure
# of many other packages (e.g. transformers) will require a large
# rebuild.
pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
(
python-self: python-super: with python-self; {
# Python package override example:
# transformers = python-super.transformers.overrideAttrs (
# _: _: {
# src = final.fetchFromGitHub {
# owner = "huggingface";
# repo = "transformers";
# rev = "2bd4d5897dc73e8b172832070a6f9e567a0df017";
# hash = "sha256-JOIpKH9ssDEfI2Tf15e0iPKtThJwQ9GxMvRAnm+M2Pg=";
# };
# }
# );
}
)
];
# Non-python package override example:
#
# ripgrep = prev.ripgrep.overrideAttrs (
# _: _: {
# src = final.fetchFromGitHub {
# owner = "BurntSushi";
# repo = "ripgrep";
# rev = "79cbe89deb1151e703f4d91b19af9cdcc128b765";
# hash = "sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=";
# };
# });
}

View File

@ -13,6 +13,7 @@
flash-attn,
flash-attn-layer-norm,
flash-attn-rotary,
flash-attn-v1,
grpc-interceptor,
grpcio-reflection,
grpcio-status,
@ -21,6 +22,7 @@
loguru,
mamba-ssm,
marlin-kernels,
moe-kernels,
opentelemetry-api,
opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc,
@ -88,6 +90,7 @@ buildPythonPackage {
loguru
mamba-ssm
marlin-kernels
moe-kernels
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc

View File

@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] }
csv = "1.3.0"
ureq = "=2.9"
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
pyo3 = { workspace = true }
[build-dependencies]

View File

@ -146,6 +146,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Mllama,
Idefics2(Idefics2),
Ssm,
GptBigcode,
@ -159,6 +160,7 @@ pub enum Config {
#[serde(rename = "phi-msft")]
PhiMsft,
Phi3,
PhiMoe,
Llama,
Baichuan,
Paligemma(Paligemma),

View File

@ -29,7 +29,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
tracing::debug!("Loading template: {:#?}", template_str);
tracing::debug!("Loading template: {}", template_str);
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)

View File

@ -8,9 +8,11 @@ use crate::{
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
Message, PrefillToken, Token,
};
use async_stream::stream;
use async_trait::async_trait;
use chat_template::ChatTemplate;
use futures::future::try_join_all;
use futures::Stream;
use minijinja::ErrorKind;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@ -87,7 +89,14 @@ impl Infer {
pub(crate) async fn generate_stream<'a>(
&'a self,
request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> {
) -> Result<
(
OwnedSemaphorePermit,
u32, // input_length
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
),
InferError,
> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
@ -107,9 +116,18 @@ impl Infer {
})?;
let input_length = valid_request.input_length;
let generation_stream = self.backend.schedule(valid_request)?;
let mut generation_stream = self.backend.schedule(valid_request)?;
Ok((permit, input_length, generation_stream))
// Wrap generation stream to update the backend health if the stream contains an error
let final_stream = stream! {
while let Some(response) = generation_stream.next().await {
yield response.inspect_err(|_err| {
self.backend_health.store(false, Ordering::SeqCst);
})
}
};
Ok((permit, input_length, final_stream))
}
/// Tokenizer the input
@ -278,13 +296,6 @@ impl Infer {
}
}
/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);
#[derive(Debug)]
pub struct GeneratedText {
pub text: String,

View File

@ -1,4 +0,0 @@
mod queue;
mod scheduler;
pub(crate) use scheduler::BackendV2;

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,10 @@ mod kserve;
pub mod logging;
pub mod usage_stats;
mod vertex;
use crate::infer::{Infer, InferError};
use crate::server::prepare_chat_input;
use serde::{Deserialize, Serialize};
use tracing::warn;
use utoipa::ToSchema;
@ -54,32 +57,6 @@ impl std::str::FromStr for Attention {
}
}
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct GenerateVertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(ChatRequest),
}
#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]
pub instances: Vec<VertexInstance>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
pub predictions: Vec<String>,
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
@ -174,6 +151,7 @@ impl HubProcessorConfig {
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
/// A string that represents a [JSON Schema](https://json-schema.org/).
@ -230,6 +208,7 @@ pub struct Info {
}
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) struct GenerateParameters {
/// Generate best_of sequences and return the one if the highest token logprobs.
#[serde(default)]
@ -684,6 +663,7 @@ pub(crate) struct ChatCompletionChunk {
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Option<Usage>,
}
#[derive(Clone, Serialize, ToSchema)]
@ -732,6 +712,7 @@ impl ChatCompletionChunk {
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
usage: Option<Usage>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
@ -766,11 +747,13 @@ impl ChatCompletionChunk {
logprobs,
finish_reason,
}],
usage,
}
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(Debug, PartialEq, Default))]
pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
@ -880,6 +863,93 @@ pub(crate) struct ChatRequest {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
}
impl ChatRequest {
fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
presence_penalty,
frequency_penalty,
top_p,
top_logprobs,
..
} = self;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, using_tools) = prepare_chat_input(
infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
Ok((
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty,
top_k: None,
top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
},
using_tools,
))
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(Debug, PartialEq))]
struct StreamOptions {
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
#[schema(example = "true")]
include_usage: bool,
}
pub fn default_tool_prompt() -> String {
@ -969,6 +1039,7 @@ pub(crate) struct FunctionDefinition {
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) struct Tool {
// The type of the tool. Currently, only 'function' is supported.
#[schema(example = "function")]
@ -1472,6 +1543,27 @@ mod tests {
let textmsg: TextMessage = message.into();
assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
}
#[test]
fn test_chat_stream_options() {
let json = json!({
"model": "",
"stream_options": {"include_usage": true},
"messages": [{
"role": "user",
"content": "Hello"
}]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert!(matches!(
request.stream_options,
Some(StreamOptions {
include_usage: true
})
));
}
#[test]
fn openai_output() {
let message = OutputMessage::ChatMessage(TextMessage {

View File

@ -8,20 +8,20 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token,
TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ModelInfo, ModelsInfo};
@ -149,63 +149,11 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
)]
async fn get_chat_tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<ChatRequest>,
Json(chat): Json<ChatRequest>,
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
..
} = req;
let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, _using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
let generate_request = GenerateRequest {
inputs,
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty: None,
frequency_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: true,
max_new_tokens: max_tokens,
return_full_text: None,
stop: stop.unwrap_or_default(),
truncate: None,
watermark: false,
details: false,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: _grammar,
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
};
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
let input = generate_request.inputs.clone();
let encoding = infer.tokenize(generate_request).await?;
if let Some(encoding) = encoding {
@ -1162,76 +1110,20 @@ async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>,
Json(chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
logprobs,
max_tokens,
messages,
presence_penalty,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
stream_options,
logprobs,
..
} = req;
} = chat.clone();
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: req.top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};
let logprobs = logprobs.unwrap_or_default();
// static values that will be returned in all cases
let model_id = info.model_id.clone();
@ -1265,6 +1157,28 @@ async fn chat_completions(
(content, None)
};
let (usage, finish_reason) = match stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
event
.json_data(CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
@ -1274,7 +1188,8 @@ async fn chat_completions(
tool_calls,
current_time,
logprobs,
stream_token.details.map(|d| d.finish_reason.format(true)),
finish_reason,
usage,
),
))
.unwrap_or_else(|e| {
@ -1361,186 +1276,6 @@ async fn chat_completions(
}
}
/// Generate tokens from Vertex request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/vertex",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
// check that theres at least one instance
if req.instances.is_empty() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
));
}
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances.iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
seed: instance.parameters.as_ref().and_then(|p| p.seed),
details: true,
decoder_input_details: true,
..Default::default()
},
},
VertexInstance::Chat(instance) => {
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
presence_penalty,
frequency_penalty,
top_p,
top_logprobs,
..
} = instance.clone();
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, _using_tools) = match prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
) {
Ok(result) => result,
Err(e) => {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Failed to prepare chat input: {}", e),
error_type: "Input preparation error".to_string(),
}),
));
}
};
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty,
top_k: None,
top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
}
}
};
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span_clone,
)
.await
.map(|(_, Json(generation))| generation.generated_text)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
});
}
// execute all futures in parallel, collect results, returning early if any error occurs
let results = futures::future::join_all(futures).await;
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
let predictions = predictions?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())
}
/// Tokenize inputs
#[utoipa::path(
post,
@ -1664,6 +1399,7 @@ StreamDetails,
ErrorResponse,
GrammarType,
Usage,
StreamOptions,
DeltaToolCall,
ToolType,
Tool,
@ -2136,9 +1872,12 @@ async fn start(
.unwrap();
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
// .unwrap();
let prom_handle = builder
.install_recorder()
.expect("failed to install metrics recorder");
// See: https://github.com/metrics-rs/metrics/issues/467#issuecomment-2022755151
let (recorder, _) = builder
.build()
.expect("failed to build prometheus recorder");
let prom_handle = recorder.handle();
metrics::set_global_recorder(recorder).expect("Failed to set global recorder");
// Metrics descriptions
metrics::describe_counter!("tgi_request_success", "Number of successful requests");
@ -2198,6 +1937,11 @@ async fn start(
metrics::Unit::Count,
"Maximum tokens for the current batch"
);
metrics::describe_gauge!(
"tgi_batch_total_tokens",
metrics::Unit::Count,
"Maximum amount of tokens in total."
);
metrics::describe_histogram!(
"tgi_request_max_new_tokens",
metrics::Unit::Count,
@ -2290,7 +2034,8 @@ async fn start(
#[cfg(feature = "google")]
{
use crate::VertexInstance;
use crate::vertex::__path_vertex_compatibility;
use crate::vertex::{VertexInstance, VertexRequest, VertexResponse};
#[derive(OpenApi)]
#[openapi(
@ -2609,7 +2354,7 @@ pub enum WebServerError {
type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input(
pub(crate) fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,

View File

@ -567,6 +567,7 @@ fn image_tokens(
use HubPreprocessorConfig::*;
match config {
Idefics => "<image>".to_string(),
Mllama => "<|image|>".to_string(),
Idefics2(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
@ -618,7 +619,7 @@ fn prepare_input(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;

360
router/src/vertex.rs Normal file
View File

@ -0,0 +1,360 @@
use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
StreamOptions, Tool, ToolChoice,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexChat {
messages: Vec<Message>,
// Messages is ignored there.
#[serde(default)]
parameters: VertexParameters,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: Option<String>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
pub logit_bias: Option<Vec<f32>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
#[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
#[serde(default = "bool::default")]
pub stream: bool,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default)]
#[schema(
nullable = true,
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
)]
pub tool_prompt: Option<String>,
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: ToolChoice,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
}
impl From<VertexChat> for ChatRequest {
fn from(val: VertexChat) -> Self {
Self {
messages: val.messages,
frequency_penalty: val.parameters.frequency_penalty,
guideline: val.parameters.guideline,
logit_bias: val.parameters.logit_bias,
logprobs: val.parameters.logprobs,
max_tokens: val.parameters.max_tokens,
model: val.parameters.model,
n: val.parameters.n,
presence_penalty: val.parameters.presence_penalty,
response_format: val.parameters.response_format,
seed: val.parameters.seed,
stop: val.parameters.stop,
stream_options: val.parameters.stream_options,
stream: val.parameters.stream,
temperature: val.parameters.temperature,
tool_choice: val.parameters.tool_choice,
tool_prompt: val.parameters.tool_prompt,
tools: val.parameters.tools,
top_logprobs: val.parameters.top_logprobs,
top_p: val.parameters.top_p,
}
}
}
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(VertexChat),
}
#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]
pub instances: Vec<VertexInstance>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
pub predictions: Vec<String>,
}
/// Generate tokens from Vertex request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/vertex",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
pub(crate) async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
// check that theres at least one instance
if req.instances.is_empty() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
));
}
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances.into_iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
seed: instance.parameters.as_ref().and_then(|p| p.seed),
details: true,
decoder_input_details: true,
..Default::default()
},
},
VertexInstance::Chat(instance) => {
let chat_request: ChatRequest = instance.into();
let (generate_request, _using_tools): (GenerateRequest, bool) =
chat_request.try_into_generate(&infer)?;
generate_request
}
};
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span_clone,
)
.await
.map(|(_, Json(generation))| generation.generated_text)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
});
}
// execute all futures in parallel, collect results, returning early if any error occurs
let results = futures::future::join_all(futures).await;
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
let predictions = predictions?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Message, MessageContent};
#[test]
fn vertex_deserialization() {
let string = serde_json::json!({
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
"parameters": {
"max_tokens": 128,
"top_p": 0.95,
"temperature": 0.7
}
});
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
let string = serde_json::json!({
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
});
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
let string = serde_json::json!({
"instances": [
{
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
"parameters": {
"max_tokens": 128,
"top_p": 0.95,
"temperature": 0.7
}
}
]
});
let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
assert_eq!(
request,
VertexRequest {
instances: vec![VertexInstance::Chat(VertexChat {
messages: vec![Message {
role: "user".to_string(),
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
name: None,
},],
parameters: VertexParameters {
max_tokens: Some(128),
top_p: Some(0.95),
temperature: Some(0.7),
..Default::default()
}
})]
}
);
}
}

View File

@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
build-flash-attention-v2-cuda:
pip install -U packaging wheel
@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
build-flash-attention-v2-rocm:
if [ ! -d 'flash-attention-v2' ]; then \
pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
fi

View File

@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

View File

@ -1,5 +1,17 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = []
extra_cflags = []
if torch.version.hip:
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_compile_args = {
"cxx": extra_cflags,
"nvcc": extra_cuda_cflags,
}
setup(
name="exllama_kernels",
@ -13,6 +25,7 @@ setup(
"exllama_kernels/cuda_func/q4_matmul.cu",
"exllama_kernels/cuda_func/q4_matrix.cu",
],
extra_compile_args=extra_compile_args,
)
],
cmdclass={"build_ext": BuildExtension},

View File

@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = ["-lineinfo", "-O3"]
extra_cflags = []
if torch.version.hip:
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_compile_args = {
"cxx": extra_cflags,
"nvcc": extra_cuda_cflags,
}

2652
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
sentencepiece = "^0.2"
tokenizers = "^0.20"
huggingface-hub = "^0.23"
transformers = "^4.43"
transformers = "^4.45"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
@ -46,6 +46,12 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.7.1"
[tool.poetry.extras]
@ -53,6 +59,7 @@ torch = ["torch"]
accelerate = ["accelerate"]
bnb = ["bitsandbytes"]
marlin = ["marlin-kernels"]
moe = ["moe-kernels"]
peft = ["peft"]
quantize = ["texttable", "datasets", "accelerate"]
outlines = ["outlines"]

View File

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -30,6 +30,10 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
fp8_e5m2 = "fp8_e5m2"
@app.command()
def serve(
model_id: str,
@ -38,6 +42,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@ -97,6 +102,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@ -114,6 +120,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,

View File

@ -1,29 +1,47 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
from text_generation_server.utils.import_utils import SYSTEM
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
]

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@ -65,5 +66,7 @@ else:
max_k: int
def clamp(self, max):
if SYSTEM == "rocm":
return self
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -287,16 +287,14 @@ elif V2:
else:
def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=None,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
@ -338,16 +336,30 @@ else:
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
True,
causal,
False,
0,
None,
)
return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
q_data_type=dtype,
page_size=page_size,
window_left=window_left,
)
yield
finally:
@ -119,7 +121,8 @@ def use_prefill_state(
num_heads: int,
num_kv_heads: int,
head_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
@ -135,7 +138,8 @@ def use_prefill_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
@ -152,11 +156,13 @@ def create_decode_state(
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@ -175,6 +181,7 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
@ -182,7 +189,8 @@ def create_decode_state_cuda_graphs(
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@ -196,7 +204,8 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
@ -231,7 +240,9 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
q_data_type=query_dtype,
data_type=dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:

Some files were not shown because too many files have changed in this diff Show More