Merge branch 'huggingface:main' into fix/dockerfile-triton

This commit is contained in:
Yaser Jaradeh 2025-01-13 11:44:53 +01:00 committed by GitHub
commit ad4dcb68df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 1551 additions and 477 deletions

214
Cargo.lock generated
View File

@ -456,18 +456,18 @@ dependencies = [
[[package]] [[package]]
name = "bit-set" name = "bit-set"
version = "0.5.3" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
dependencies = [ dependencies = [
"bit-vec", "bit-vec",
] ]
[[package]] [[package]]
name = "bit-vec" name = "bit-vec"
version = "0.6.3" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
[[package]] [[package]]
name = "bit_field" name = "bit_field"
@ -502,6 +502,12 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "borrow-or-share"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
[[package]] [[package]]
name = "built" name = "built"
version = "0.7.5" version = "0.7.5"
@ -1139,6 +1145,15 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "email_address"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "encode_unicode" name = "encode_unicode"
version = "0.3.6" version = "0.3.6"
@ -1196,12 +1211,13 @@ dependencies = [
[[package]] [[package]]
name = "fancy-regex" name = "fancy-regex"
version = "0.11.0" version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298"
dependencies = [ dependencies = [
"bit-set", "bit-set",
"regex", "regex-automata 0.4.9",
"regex-syntax 0.8.5",
] ]
[[package]] [[package]]
@ -1247,6 +1263,17 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
[[package]]
name = "fluent-uri"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
dependencies = [
"borrow-or-share",
"ref-cast",
"serde",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@ -1285,9 +1312,9 @@ dependencies = [
[[package]] [[package]]
name = "fraction" name = "fraction"
version = "0.13.1" version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678" checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
dependencies = [ dependencies = [
"lazy_static", "lazy_static",
"num", "num",
@ -1414,10 +1441,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"wasi", "wasi",
"wasm-bindgen",
] ]
[[package]] [[package]]
@ -1573,7 +1598,7 @@ dependencies = [
"native-tls", "native-tls",
"num_cpus", "num_cpus",
"rand", "rand",
"reqwest", "reqwest 0.11.27",
"serde", "serde",
"serde_json", "serde_json",
"thiserror", "thiserror",
@ -2051,15 +2076,6 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iso8601"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.10.5" version = "0.10.5"
@ -2128,32 +2144,27 @@ dependencies = [
[[package]] [[package]]
name = "jsonschema" name = "jsonschema"
version = "0.17.1" version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978" checksum = "74d8eb539cdb4222da29bb658cc9881aa2477b33fb1a74c5c31450395fc1a4b2"
dependencies = [ dependencies = [
"ahash", "ahash",
"anyhow", "base64 0.22.1",
"base64 0.21.7",
"bytecount", "bytecount",
"clap 4.5.21", "email_address",
"fancy-regex", "fancy-regex",
"fraction", "fraction",
"getrandom", "idna",
"iso8601",
"itoa", "itoa",
"memchr",
"num-cmp", "num-cmp",
"once_cell", "once_cell",
"parking_lot",
"percent-encoding", "percent-encoding",
"regex", "referencing",
"reqwest", "regex-syntax 0.8.5",
"reqwest 0.12.9",
"serde", "serde",
"serde_json", "serde_json",
"time", "uuid-simd",
"url",
"uuid",
] ]
[[package]] [[package]]
@ -2984,6 +2995,12 @@ dependencies = [
"serde_json", "serde_json",
] ]
[[package]]
name = "outref"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a"
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
@ -3557,6 +3574,39 @@ dependencies = [
"thiserror", "thiserror",
] ]
[[package]]
name = "ref-cast"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.89",
]
[[package]]
name = "referencing"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093a875008827c0ae15c746189966e162faa05bf347719d06302c548ac63630f"
dependencies = [
"ahash",
"fluent-uri",
"once_cell",
"percent-encoding",
"serde_json",
]
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.11.1" version = "1.11.1"
@ -3641,6 +3691,42 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "reqwest"
version = "0.12.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.5.1",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"once_cell",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"windows-registry",
]
[[package]] [[package]]
name = "rgb" name = "rgb"
version = "0.8.50" version = "0.8.50"
@ -4220,6 +4306,9 @@ name = "sync_wrapper"
version = "1.0.2" version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]] [[package]]
name = "synstructure" name = "synstructure"
@ -4404,7 +4493,7 @@ dependencies = [
"once_cell", "once_cell",
"pyo3", "pyo3",
"regex", "regex",
"reqwest", "reqwest 0.11.27",
"serde", "serde",
"serde_json", "serde_json",
"thiserror", "thiserror",
@ -4445,7 +4534,7 @@ dependencies = [
"pyo3", "pyo3",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest 0.11.27",
"serde", "serde",
"serde_json", "serde_json",
"sysinfo", "sysinfo",
@ -4493,7 +4582,7 @@ dependencies = [
"prost-build", "prost-build",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest 0.11.27",
"serde", "serde",
"serde_json", "serde_json",
"slotmap", "slotmap",
@ -4544,7 +4633,7 @@ dependencies = [
"prost-build", "prost-build",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest 0.11.27",
"serde", "serde",
"serde_json", "serde_json",
"slotmap", "slotmap",
@ -5298,6 +5387,17 @@ dependencies = [
"syn 2.0.89", "syn 2.0.89",
] ]
[[package]]
name = "uuid-simd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
dependencies = [
"outref",
"uuid",
"vsimd",
]
[[package]] [[package]]
name = "v_frame" name = "v_frame"
version = "0.3.8" version = "0.3.8"
@ -5349,6 +5449,12 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "vsimd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
[[package]] [[package]]
name = "walkdir" name = "walkdir"
version = "2.5.0" version = "2.5.0"
@ -5558,6 +5664,36 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.45.0" version = "0.45.0"

View File

@ -234,6 +234,7 @@ FROM kernel-builder AS vllm-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-vllm Makefile COPY server/Makefile-vllm Makefile
RUN pip install setuptools_scm
# Build specific version of vllm # Build specific version of vllm
RUN make build-vllm-rocm RUN make build-vllm-rocm
@ -267,6 +268,15 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build RUN python setup.py build
FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \
git checkout ${MOE_KERNELS_BRANCH} && \
python setup.py install
FROM install_deps AS base-copy FROM install_deps AS base-copy
# Text Generation Inference base env # Text Generation Inference base env
@ -289,6 +299,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
# Copy build artifacts from exllamav2 kernels builder # Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server

View File

@ -97,11 +97,10 @@ ENV HF_HOME=/data \
WORKDIR /usr/src WORKDIR /usr/src
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
@ -119,6 +118,9 @@ ENV CCL_ZE_IPC_EXCHANGE=sockets
#ENV TORCH_LLM_ALLREDUCE=1 #ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router

View File

@ -1,7 +1,7 @@
<div align="center"> <div align="center">
<a href="https://www.youtube.com/watch?v=jlMAX2Oaht0"> <a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">
<img width=560 width=315 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png"> <img width=560 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png">
</a> </a>
# Text Generation Inference # Text Generation Inference
@ -141,8 +141,8 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri
For example, if you want to serve the gated Llama V2 model variants: For example, if you want to serve the gated Llama V2 model variants:
1. Go to https://huggingface.co/settings/tokens 1. Go to https://huggingface.co/settings/tokens
2. Copy your cli READ token 2. Copy your CLI READ token
3. Export `HF_TOKEN=<your cli READ token>` 3. Export `HF_TOKEN=<your CLI READ token>`
or with Docker: or with Docker:
@ -157,7 +157,7 @@ docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/da
### A note on Shared Memory (shm) ### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by [`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
`PyTorch` to do distributed training/inference. `text-generation-inference` make `PyTorch` to do distributed training/inference. `text-generation-inference` makes
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models. use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
@ -196,7 +196,7 @@ Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with T
You can also opt to install `text-generation-inference` locally. You can also opt to install `text-generation-inference` locally.
First clone the repository and change directoy into it: First clone the repository and change directory into it:
```shell ```shell
git clone https://github.com/huggingface/text-generation-inference git clone https://github.com/huggingface/text-generation-inference
@ -213,7 +213,7 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
conda create -n text-generation-inference python=3.11 conda create -n text-generation-inference python=3.11
conda activate text-generation-inference conda activate text-generation-inference
#using pyton venv #using python venv
python3 -m venv .venv python3 -m venv .venv
source .venv/bin/activate source .venv/bin/activate
``` ```

View File

@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" } grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28" futures = "0.3.28"
hf-hub = { workspace = true } hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.28.0" }
metrics = { workspace = true } metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true } metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"

View File

@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" } grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28" futures = "0.3.28"
hf-hub = { workspace = true } hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.28.0" }
metrics = { workspace = true } metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true } metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"

View File

@ -187,8 +187,6 @@ In addition to the grammar parameter, we've also introduced a set of tools and f
Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
```json ```json
curl localhost:3000/v1/chat/completions \ curl localhost:3000/v1/chat/completions \
-X POST \ -X POST \

View File

@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)

View File

@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1732218602, "lastModified": 1736436388,
"narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=", "narHash": "sha256-CIyxVPpM9RrSwthNT/4DQ10YPk/uwzP7AeE83kBNsrE=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "f79638ac4e420e661321261744e745a3a747e182", "rev": "5103c3fb1f9ad1fd33b6e09ff05e957884b112d5",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -354,6 +354,7 @@ def launcher(event_loop):
kv_cache_dtype: Optional[str] = None, kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_input_tokens: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None, lora_adapters: Optional[List[str]] = None,
@ -402,6 +403,9 @@ def launcher(event_loop):
if max_input_length: if max_input_length:
args.append("--max-input-length") args.append("--max-input-length")
args.append(str(max_input_length)) args.append(str(max_input_length))
if max_input_tokens:
args.append("--max-input-tokens")
args.append(str(max_input_tokens))
if max_batch_prefill_tokens: if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens") args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens)) args.append(str(max_batch_prefill_tokens))

View File

@ -32,7 +32,7 @@
}, },
{ {
"id": 1101, "id": 1101,
"logprob": -1.0947266, "logprob": -1.0136719,
"special": false, "special": false,
"text": " also" "text": " also"
}, },
@ -56,13 +56,13 @@
}, },
{ {
"id": 4009, "id": 4009,
"logprob": -0.15563965, "logprob": -0.21923828,
"special": false, "special": false,
"text": " network" "text": " network"
}, },
{ {
"id": 477, "id": 477,
"logprob": -1.4003906, "logprob": -1.4824219,
"special": false, "special": false,
"text": " or" "text": " or"
} }

View File

@ -8,7 +8,7 @@
"tokens": [ "tokens": [
{ {
"id": 1939, "id": 1939,
"logprob": -2.2675781, "logprob": -2.2460938,
"special": false, "special": false,
"text": "?\n\n" "text": "?\n\n"
}, },
@ -20,13 +20,13 @@
}, },
{ {
"id": 20909, "id": 20909,
"logprob": -0.37695312, "logprob": -0.48608398,
"special": false, "special": false,
"text": " Learning" "text": " Learning"
}, },
{ {
"id": 4102, "id": 4102,
"logprob": -1.9316406, "logprob": -2.265625,
"special": false, "special": false,
"text": " " "text": " "
}, },
@ -38,25 +38,13 @@
}, },
{ {
"id": 458, "id": 458,
"logprob": -0.80859375, "logprob": -0.6328125,
"special": false, "special": false,
"text": " an" "text": " an"
}, },
{
"id": 3082,
"logprob": -1.4541016,
"special": false,
"text": " area"
},
{
"id": 315,
"logprob": 0.0,
"special": false,
"text": " of"
},
{ {
"id": 20443, "id": 20443,
"logprob": -0.5136719, "logprob": -0.1796875,
"special": false, "special": false,
"text": " artificial" "text": " artificial"
}, },
@ -65,9 +53,21 @@
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " intelligence" "text": " intelligence"
},
{
"id": 320,
"logprob": -0.37695312,
"special": false,
"text": " ("
},
{
"id": 15469,
"logprob": 0.0,
"special": false,
"text": "AI"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" "generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI"
} }

View File

@ -9,61 +9,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -82,61 +82,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -155,61 +155,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -228,61 +228,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }

View File

@ -44,7 +44,7 @@
}, },
{ {
"id": 38397, "id": 38397,
"logprob": -0.12695312, "logprob": 0.0,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },

View File

@ -14,60 +14,60 @@
}, },
{ {
"id": 573, "id": 573,
"logprob": -0.18493652, "logprob": -0.19030762,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 16819, "id": 16819,
"logprob": -1.4804688, "logprob": -1.4863281,
"special": false, "special": false,
"text": " detection" "text": " detection"
}, },
{ {
"id": 576, "id": 576,
"logprob": -0.7011719, "logprob": -0.7089844,
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": -2.0410156,
"special": false,
"text": " the"
},
{
"id": 8566,
"logprob": 0.0,
"special": false,
"text": " presence"
},
{
"id": 689,
"logprob": -0.16491699,
"special": false,
"text": " or"
},
{
"id": 14862,
"logprob": 0.0,
"special": false,
"text": " absence"
},
{
"id": 576,
"logprob": -0.9970703,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 671, "id": 671,
"logprob": -2.1738281, "logprob": -0.5292969,
"special": false, "special": false,
"text": " an" "text": " an"
},
{
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " RNA"
},
{
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " virus"
},
{
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " in"
},
{
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " patients"
},
{
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " who"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request for the detection of an RNA virus in patients who" "generated_text": "Test request for the detection of the presence or absence of an"
} }

View File

@ -8,7 +8,7 @@
"tokens": [ "tokens": [
{ {
"id": 2284, "id": 2284,
"logprob": -0.296875, "logprob": -0.31323242,
"special": false, "special": false,
"text": "():" "text": "():"
}, },
@ -38,13 +38,13 @@
}, },
{ {
"id": 10914, "id": 10914,
"logprob": -0.7734375, "logprob": -0.7871094,
"special": false, "special": false,
"text": " World" "text": " World"
}, },
{ {
"id": 16013, "id": 16013,
"logprob": -0.61816406, "logprob": -0.64746094,
"special": false, "special": false,
"text": "!\")" "text": "!\")"
}, },
@ -62,7 +62,7 @@
}, },
{ {
"id": 610, "id": 610,
"logprob": -0.4152832, "logprob": -0.41064453,
"special": false, "special": false,
"text": "def" "text": "def"
}, },
@ -92,7 +92,7 @@
}, },
{ {
"id": 444, "id": 444,
"logprob": -0.21618652, "logprob": -0.21655273,
"special": false, "special": false,
"text": "name" "text": "name"
}, },
@ -139,28 +139,16 @@
"text": "Hello" "text": "Hello"
}, },
{ {
"id": 925, "id": 332,
"logprob": -3.3476562, "logprob": -0.034698486,
"special": false, "special": false,
"text": " %" "text": " \""
}, },
{ {
"id": 120, "id": 494,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "s" "text": " +"
},
{
"id": 11571,
"logprob": -0.08892822,
"special": false,
"text": "!\""
},
{
"id": 925,
"logprob": 0.0,
"special": false,
"text": " %"
}, },
{ {
"id": 655, "id": 655,
@ -169,10 +157,22 @@
"text": " name" "text": " name"
}, },
{ {
"id": 46, "id": 494,
"logprob": -0.20141602,
"special": false,
"text": " +"
},
{
"id": 332,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": ")" "text": " \""
},
{
"id": 16013,
"logprob": 0.0,
"special": false,
"text": "!\")"
}, },
{ {
"id": 222, "id": 222,
@ -230,7 +230,7 @@
}, },
{ {
"id": 400, "id": 400,
"logprob": -0.074279785, "logprob": 0.0,
"special": false, "special": false,
"text": "age" "text": "age"
}, },
@ -289,22 +289,34 @@
"text": "Hello" "text": "Hello"
}, },
{ {
"id": 925, "id": 332,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " %" "text": " \""
}, },
{ {
"id": 120, "id": 494,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "s" "text": " +"
}, },
{ {
"id": 49, "id": 655,
"logprob": -0.07891846, "logprob": 0.0,
"special": false, "special": false,
"text": "," "text": " name"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 3021,
"logprob": -0.5761719,
"special": false,
"text": " \","
}, },
{ {
"id": 863, "id": 863,
@ -319,55 +331,43 @@
"text": " are" "text": " are"
}, },
{ {
"id": 925, "id": 332,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " %" "text": " \""
}, },
{ {
"id": 105, "id": 494,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "d" "text": " +"
}, },
{ {
"id": 11339, "id": 615,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " years" "text": " str"
}, },
{ {
"id": 3627, "id": 45,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " old" "text": "("
}, },
{ {
"id": 11571, "id": 400,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "!\"" "text": "age"
}, },
{ {
"id": 925, "id": 46,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " %" "text": ")"
},
{
"id": 327,
"logprob": 0.0,
"special": false,
"text": " ("
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name" "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
} }

View File

@ -0,0 +1,67 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 9,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2684,
"logprob": -0.24902344,
"special": false,
"text": " There"
},
{
"id": 374,
"logprob": -0.0703125,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.23535156,
"special": false,
"text": " a"
},
{
"id": 35372,
"logprob": -0.125,
"special": false,
"text": " statue"
},
{
"id": 304,
"logprob": -0.30273438,
"special": false,
"text": " in"
},
{
"id": 279,
"logprob": -0.20507812,
"special": false,
"text": " the"
},
{
"id": 2217,
"logprob": -0.076171875,
"special": false,
"text": " image"
},
{
"id": 13,
"logprob": -0.053710938,
"special": false,
"text": "."
},
{
"id": 128258,
"logprob": -0.011352539,
"special": true,
"text": "<end_of_utterance>"
}
],
"top_tokens": null
},
"generated_text": " There is a statue in the image."
}

View File

@ -0,0 +1,61 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 8,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.118652344,
"special": false,
"text": " A"
},
{
"id": 11426,
"logprob": -0.28320312,
"special": false,
"text": " bee"
},
{
"id": 335,
"logprob": -0.95703125,
"special": false,
"text": " on"
},
{
"id": 253,
"logprob": -0.06982422,
"special": false,
"text": " a"
},
{
"id": 11986,
"logprob": -0.49414062,
"special": false,
"text": " pink"
},
{
"id": 8525,
"logprob": -0.07763672,
"special": false,
"text": " flower"
},
{
"id": 30,
"logprob": -1.0703125,
"special": false,
"text": "."
},
{
"id": 49154,
"logprob": -0.092285156,
"special": true,
"text": "<end_of_utterance>"
}
],
"top_tokens": null
},
"generated_text": " A bee on a pink flower."
}

View File

@ -64,7 +64,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params(
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert ( assert (
response.generated_text response.generated_text
== "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" == "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI"
) )
assert response == response_snapshot assert response == response_snapshot

View File

@ -0,0 +1,31 @@
import pytest
@pytest.fixture(scope="module")
def flash_idefics3_next_handle(launcher):
with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_idefics3_next(flash_idefics3_next_handle):
await flash_idefics3_next_handle.health(300)
return flash_idefics3_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot):
ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
query = "What is in this image?"
response = await flash_idefics3_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text == " There is a statue in the image."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 9
assert response == response_snapshot

View File

@ -0,0 +1,31 @@
import pytest
@pytest.fixture(scope="module")
def flash_smolvlm_next_handle(launcher):
with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_smolvlm_next(flash_smolvlm_next_handle):
await flash_smolvlm_next_handle.health(300)
return flash_smolvlm_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):
ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
query = "What is in this image?"
response = await flash_smolvlm_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text == " A bee on a pink flower."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 8
assert response == response_snapshot

View File

@ -1652,7 +1652,11 @@ impl From<&str> for Gpu {
"nvidia-l40s" => Gpu::L40S, "nvidia-l40s" => Gpu::L40S,
"nvidia-a10g" => Gpu::A10G, "nvidia-a10g" => Gpu::A10G,
"nvidia-h100-80gb-hbm3" => Gpu::H100, "nvidia-h100-80gb-hbm3" => Gpu::H100,
"nvidia-h100-nvl" => Gpu::H100,
"nvidia-h100" => Gpu::H100,
"nvidia-a100-sxm4-80gb" => Gpu::A100, "nvidia-a100-sxm4-80gb" => Gpu::A100,
"nvidia-a100-sxm4-40gb" => Gpu::A100,
"nvidia-a100-80gb-pcie" => Gpu::A100,
"nvidia-a100" => Gpu::A100, "nvidia-a100" => Gpu::A100,
card => Gpu::Unknown(card.to_string()), card => Gpu::Unknown(card.to_string()),
} }

View File

@ -17,7 +17,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28" futures = "0.3.28"
hf-hub = { workspace = true } hf-hub = { workspace = true }
itertools = "0.10" itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.28.0" }
metrics = { workspace = true } metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true } metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
@ -25,7 +25,7 @@ opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0" opentelemetry-otlp = "0.13.0"
outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" } outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] } reqwest = { version = "0.11.20", features = ["blocking"] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
thiserror = "1.0.48" thiserror = "1.0.48"

View File

@ -110,6 +110,24 @@ pub struct ClipVisionModel {
patch_size: usize, patch_size: usize,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics3 {}
impl Idefics3 {
pub fn get_max_longest_edge(&self) -> usize {
364
}
pub fn get_number_of_features(&self) -> usize {
169
}
pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
1456
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct Idefics2 {} pub struct Idefics2 {}
@ -178,6 +196,7 @@ pub enum Config {
Idefics, Idefics,
Mllama, Mllama,
Idefics2(Idefics2), Idefics2(Idefics2),
Idefics3(Idefics3),
Ssm, Ssm,
GptBigcode, GptBigcode,
Granite, Granite,

View File

@ -205,6 +205,7 @@ pub async fn kserve_model_infer(
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: str_input.to_string(), inputs: str_input.to_string(),
parameters: payload.parameters.clone(), parameters: payload.parameters.clone(),
add_special_tokens: true,
}; };
let infer = infer.clone(); let infer = infer.clone();
let compute_type = compute_type.clone(); let compute_type = compute_type.clone();
@ -212,7 +213,7 @@ pub async fn kserve_model_infer(
async move { async move {
generate_internal(infer, compute_type, Json(generate_request), span) generate_internal(infer, compute_type, Json(generate_request), span)
.await .await
.map(|(_, Json(generation))| { .map(|(_, _, Json(generation))| {
let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
OutputChunk { OutputChunk {
name: output.name.clone(), name: output.name.clone(),

View File

@ -170,6 +170,7 @@ impl TokenizerConfigToken {
#[serde(tag = "processor_class")] #[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig { pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor), Idefics2Processor(Idefics2Preprocessor),
Idefics3Processor(Idefics2Preprocessor),
} }
impl HubPreprocessorConfig { impl HubPreprocessorConfig {

View File

@ -7,7 +7,6 @@ use crate::{
use crate::{PyTokenizer, Tokenizer}; use crate::{PyTokenizer, Tokenizer};
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader}; use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema};
use outlines_core::json_schema::to_regex as json_schema_to_regex; use outlines_core::json_schema::to_regex as json_schema_to_regex;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
@ -355,9 +354,7 @@ impl Validation {
}?; }?;
// Check if the json is a valid JSONSchema // Check if the json is a valid JSONSchema
JSONSchema::options() jsonschema::draft202012::meta::validate(&json)
.with_draft(Draft::Draft202012)
.compile(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
// The schema can be valid but lack properties. // The schema can be valid but lack properties.
@ -614,6 +611,73 @@ fn image_tokens(
image_string image_string
} }
Idefics3(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
const GLOBAL_IMG: &str = "<global-img>";
let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
// resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
let (height, width) = if height > max_longest_edge_for_image_resize
|| width > max_longest_edge_for_image_resize
{
let aspect_ratio = height as f32 / width as f32;
if height > width {
(
max_longest_edge_for_image_resize,
(max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
)
} else {
(
(max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
max_longest_edge_for_image_resize,
)
}
} else {
(height, width)
};
let image_seq_len = config.get_number_of_features();
let max_edge = config.get_max_longest_edge();
let (image_rows, image_cols) = if height > max_edge || width > max_edge {
(
(height as f32 / max_edge as f32).ceil() as usize,
(width as f32 / max_edge as f32).ceil() as usize,
)
} else {
(0, 0)
};
let mut image_string = String::new();
if image_rows == 0 && image_cols == 0 {
// Single image case
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
} else {
// Split image case
for n_h in 0..image_rows {
for n_w in 0..image_cols {
image_string.push_str(FAKE);
image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
image_string.push_str(&IMAGE.repeat(image_seq_len));
}
image_string.push('\n');
}
image_string.push('\n');
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
}
image_string
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)), Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
Qwen2Vl(config) => format!( Qwen2Vl(config) => format!(
@ -647,7 +711,8 @@ fn prepare_input<T: TokenizerTrait>(
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config {
Some( Some(
config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)), config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
| Qwen2Vl(_)),
) => { ) => {
let mut input_chunks = Vec::new(); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());

View File

@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
build-flash-attention-v2-cuda: build-flash-attention-v2-cuda:
pip install -U packaging wheel pip install -U packaging wheel

View File

@ -1,2 +1,5 @@
install-flashinfer: install-flashinfer:
pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 # We need fsspec as an additional dependency, but
# `pip install flashinfer` cannot resolve it.
pip install fsspec
pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4

View File

@ -1,4 +1,4 @@
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247 commit_rocm := de990cd12537f78f74e40b5c8ee1a62d63d734dd
build-vllm-rocm: build-vllm-rocm:
if [ ! -d 'vllm' ]; then \ if [ ! -d 'vllm' ]; then \

28
server/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
[[package]] [[package]]
name = "accelerate" name = "accelerate"
@ -1289,12 +1289,12 @@ files = [
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.3.6" version = "0.3.7"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.3.6+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:afedaa9a15e8991442bc8c81f62833fbf5c1556ae9d7a5a9e13b747ce97beef9"}, {file = "marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:bb416d14623dc0ad0eeb2835446c37a41f994555f1baec8701de6d4c1fc17ec8"},
] ]
[package.dependencies] [package.dependencies]
@ -1302,16 +1302,16 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp310-cp310-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.3.6" version = "0.3.7"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.3.6+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:c0c05621d5e87144415d8a6e439072bd844d5f3cb55e4c4c69eabdc4c94610f4"}, {file = "marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:a89bb61d718002d4432158641bce95c6fd68f9ee1a7d5402dd283903397f3185"},
] ]
[package.dependencies] [package.dependencies]
@ -1319,16 +1319,16 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp311-cp311-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.3.6" version = "0.3.7"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.3.6+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:3be4662c8d25a3cdb1793dafe0e2e76dd600913a69a468e2c68d1fed4e149255"}, {file = "marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:ed938d196fc5e9cce9fc44cd2b889d5adc5ca7475c8a23858f1474d29e38bdbf"},
] ]
[package.dependencies] [package.dependencies]
@ -1336,16 +1336,16 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp312-cp312-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.3.6" version = "0.3.7"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.3.6+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:89eac9d46bc084a256b538afda6053683eb7e505db0e0d4f6dbeca32368caac6"}, {file = "marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:113c54f68565ad476ca12366b4de92131fa3e9ddb16cbe8ad63272972a15ac28"},
] ]
[package.dependencies] [package.dependencies]
@ -1353,7 +1353,7 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp39-cp39-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
[[package]] [[package]]
name = "mdurl" name = "mdurl"
@ -4097,4 +4097,4 @@ torch = ["torch"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<3.13" python-versions = ">=3.9,<3.13"
content-hash = "c7fdcff2b752cd3beb3995c1ecd15f0f4d9b4e117048b06ab991c6d0e0c86ff3" content-hash = "25f96d5dea777bfa7a959f863e35d2e05e1a6172d0dd45193dbe25ac2f32cc25"

View File

@ -48,10 +48,10 @@ attention-kernels = [
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
marlin-kernels = [ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
moe-kernels = [ moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

View File

@ -60,8 +60,7 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward( return decode_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. query,
query.contiguous(),
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
@ -231,8 +230,7 @@ def attention(
softcap = 0.0 softcap = 0.0
return prefill_with_paged_kv_state.get().forward( return prefill_with_paged_kv_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. query,
query.contiguous(),
causal=causal, causal=causal,
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,

View File

@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
page_size: int, page_size: int,
dtype: torch.dtype, kv_dtype: torch.dtype,
q_dtype: torch.dtype,
window_left: int, window_left: int,
): ):
""" """
@ -91,9 +92,10 @@ def use_prefill_with_paged_kv_state(
num_qo_heads=num_heads, num_qo_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_dim=head_size, head_dim=head_size,
q_data_type=dtype, kv_data_type=kv_dtype,
q_data_type=q_dtype,
page_size=page_size, page_size=page_size,
window_left=window_left, window_left=-1 if window_left is None else window_left,
) )
yield yield
finally: finally:
@ -113,41 +115,6 @@ def create_prefill_state(
) )
@contextmanager
def use_prefill_state(
*,
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
cu_seqlens: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
token = prefill_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
kv_indptr=cu_seqlens,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_state.reset(token)
def create_decode_state( def create_decode_state(
*, *,
device: torch.device, device: torch.device,
@ -205,7 +172,7 @@ def use_decode_state(
head_size: int, head_size: int,
page_size: int, page_size: int,
kv_cache_dtype: torch.dtype, kv_cache_dtype: torch.dtype,
dtype: torch.dtype, q_dtype: torch.dtype,
window_left: int, window_left: int,
): ):
""" """
@ -242,8 +209,8 @@ def use_decode_state(
head_dim=head_size, head_dim=head_size,
page_size=page_size, page_size=page_size,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
q_data_type=dtype, q_data_type=q_dtype,
window_left=window_left, window_left=-1 if window_left is None else window_left,
) )
yield yield
finally: finally:

View File

@ -215,7 +215,9 @@ def paged_reshape_and_cache(
raise ImportError( raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
) )
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
)
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex

View File

@ -5,27 +5,47 @@ from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
from loguru import logger from loguru import logger
import vllm._custom_ops as ops
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE_V1V2 = 512 _PARTITION_SIZE_V1V2 = 1024
_PARTITION_SIZE_CUSTOM = 256 _PARTITION_SIZE_CUSTOM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_MI250_MI300 = any(
arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]
)
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck" ENGINE = "triton" if use_triton else "ck"
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom def _use_rocm_custom_paged_attention(
except ImportError as e: qtype: torch.dtype,
log_master( head_size: int,
logger.info, block_size: int,
f"Custom Paged Attention not available. Complete error: {e}", gqa_ratio: int,
max_seq_len: int,
) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (
use_rocm_custom_paged_attn
and _ON_MI250_MI300
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 131072
) )
use_rocm_custom_paged_attn = False
def paged_attention( def paged_attention(
@ -57,22 +77,50 @@ def paged_attention(
# limitations under the License. # limitations under the License.
# #
if ATTENTION == "flashdecoding":
max_q = 1
max_k = max_s
import flash_attn_2_cuda
if softcap is None:
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
kv_cache.key,
kv_cache.value,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None, # pad_k
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
softcap,
False, # return softmax
None, # generator
)
return out[0]
if softcap is not None: if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping") raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size] # value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache.value.shape[3] # block_size = kv_cache.value.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
num_kv_heads = kv_cache.key.shape[1] num_kv_heads = kv_cache.key.shape[1]
gqa_ratio = num_heads // num_kv_heads gqa_ratio = num_heads // num_kv_heads
use_custom = ( use_custom = _use_rocm_custom_paged_attention(
use_rocm_custom_paged_attn query.dtype, head_size, block_size, gqa_ratio, max_s
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
) )
if not use_custom: if not use_custom:
@ -90,8 +138,6 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
import vllm._custom_ops as ops
use_v1 = ( use_v1 = (
max_s <= 8192 max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512) and (max_num_partitions == 1 or num_seqs * num_heads > 512)
@ -103,7 +149,7 @@ def paged_attention(
query, query,
kv_cache.key, kv_cache.key,
kv_cache.value, kv_cache.value,
kv_head_mapping, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
@ -112,6 +158,7 @@ def paged_attention(
None, None,
"auto", "auto",
1.0, 1.0,
1.0,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
@ -137,7 +184,7 @@ def paged_attention(
query, query,
kv_cache.key, kv_cache.key,
kv_cache.value, kv_cache.value,
kv_head_mapping, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
@ -146,9 +193,10 @@ def paged_attention(
None, None,
"auto", "auto",
1.0, 1.0,
1.0,
) )
else: else:
paged_attention_custom( ops.paged_attention_rocm(
out, out,
exp_sums, exp_sums,
max_logits, max_logits,
@ -164,6 +212,10 @@ def paged_attention(
max_s, max_s,
None, None,
"auto", "auto",
1.0,
1.0,
None,
_PARTITION_SIZE,
) )
return out return out
@ -232,14 +284,15 @@ def attention(
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
query, query,
key, # flashdecoding: pass the KV caches, paged: pass the KV.
value, kv_cache.key if ATTENTION == "flashdecoding" else key,
kv_cache.value if ATTENTION == "flashdecoding" else value,
out, out,
seqlen.cu_seqlen_q, seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q, seqlen.cu_seqlen_k,
None,
None, None,
None, None,
block_tables if ATTENTION == "flashdecoding" else None,
None, None,
seqlen.max_q, seqlen.max_q,
seqlen.max_k, seqlen.max_k,

View File

@ -72,7 +72,7 @@ if SYSTEM == "cuda":
return normed_hidden_states, residual return normed_hidden_states, residual
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops import vllm._custom_ops as ops
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
@ -121,6 +121,27 @@ class FastRMSNorm(nn.Module):
residual is not None, residual is not None,
) )
return out, residual if residual is not None else hidden_states return out, residual if residual is not None else hidden_states
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
ops.fused_add_rms_norm(
hidden_states,
residual,
self.weight.data,
self.variance_epsilon,
)
return hidden_states, residual
residual = hidden_states
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
elif hidden_states.shape[-1] > 8192: elif hidden_states.shape[-1] > 8192:
if residual is not None: if residual is not None:
hidden_states += residual hidden_states += residual
@ -164,20 +185,6 @@ class FastRMSNorm(nn.Module):
res = hidden_states res = hidden_states
return normed_hidden_states, res return normed_hidden_states, res
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."

View File

@ -11,10 +11,10 @@ if SYSTEM == "rocm":
if ROCM_USE_SKINNY_GEMM: if ROCM_USE_SKINNY_GEMM:
try: try:
from vllm import _custom_C import vllm._custom_ops as ops
except Exception as e: except Exception as e:
raise ImportError( raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}" f"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}"
) )
@ -95,12 +95,12 @@ class FastLinearROCm(torch.nn.Module):
out = torch.empty( out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
) )
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count) ops.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192: elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty( out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
) )
_custom_C.LLMM1(weight, inp, out, 4) ops.LLMM1(weight, inp, out, 4)
else: else:
out = F.linear(inp, weight) out = F.linear(inp, weight)

View File

@ -24,10 +24,7 @@ from text_generation_server.utils.weights import (
UnquantizedWeight, UnquantizedWeight,
) )
if SYSTEM == "rocm": if SYSTEM == "ipex":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else: else:
from moe_kernels.fused_moe import fused_topk, grouped_topk from moe_kernels.fused_moe import fused_topk, grouped_topk

View File

@ -1,52 +0,0 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.distributed
# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids

View File

@ -6,9 +6,7 @@ import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "rocm": if SYSTEM == "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else: else:
from moe_kernels.fused_moe import fused_moe from moe_kernels.fused_moe import fused_moe

View File

@ -7,7 +7,7 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda": if SYSTEM == "cuda":
import rotary_emb import rotary_emb
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops import vllm._custom_ops as ops
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex

View File

@ -152,6 +152,9 @@ try:
from text_generation_server.models.custom_modeling.idefics2 import ( from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration, Idefics2ForConditionalGeneration,
) )
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import ( from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration, Qwen2VLForConditionalGeneration,
) )
@ -188,6 +191,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True, "multimodal": True,
} }
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = { LLAVA_NEXT = {
"type": "llava_next", "type": "llava_next",
"name": "Llava Next (1.6)", "name": "Llava Next (1.6)",
@ -1253,6 +1262,24 @@ def get_model(
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS3:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == PALIGEMMA: if model_type == PALIGEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return VlmCausalLM( return VlmCausalLM(

View File

@ -75,7 +75,7 @@ class CohereRotary(PositionRotaryEmbedding):
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops import vllm._custom_ops as ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

View File

@ -23,9 +23,7 @@ from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm": if SYSTEM == "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else: else:
from moe_kernels.fused_moe import fused_moe from moe_kernels.fused_moe import fused_moe

View File

@ -43,9 +43,9 @@ from text_generation_server.utils.weights import Weights
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
from vllm import _custom_C import vllm._custom_ops as ops
except Exception as e: except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig): class DeepseekV2Config(PretrainedConfig):
@ -408,7 +408,7 @@ class DeepseekV2MLP(nn.Module):
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device="cuda", device="cuda",
) )
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce) return self.down_proj(out, reduce=reduce)
else: else:
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)

View File

@ -91,7 +91,7 @@ class GPTJRotary(PositionRotaryEmbedding):
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops import vllm._custom_ops as ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

View File

@ -64,9 +64,9 @@ if SYSTEM != "ipex":
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
from vllm import _custom_C import vllm._custom_ops as ops
except Exception as e: except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
def load_attention(config, prefix: str, weights, layer_id): def load_attention(config, prefix: str, weights, layer_id):
@ -392,7 +392,7 @@ class LlamaMLP(nn.Module):
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device="cuda", device="cuda",
) )
_custom_C.LLMM_Silu( ops.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
) )
return self.down_proj(out, adapter_data) return self.down_proj(out, adapter_data)
@ -515,9 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=0, index=0,
prefix=( prefix=f"{prefix}.layers.0",
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@ -533,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaCrossLayer( FlashLlamaCrossLayer(
index=layer_id, index=layer_id,
prefix=( prefix=(f"{prefix}.layers.{layer_id}"),
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@ -546,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=layer_id, index=layer_id,
prefix=( prefix=(f"{prefix}.layers.{layer_id}"),
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@ -561,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=last_layer_id, index=last_layer_id,
prefix=( prefix=(f"{prefix}.layers.{last_layer_id}"),
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}.model.layers.{last_layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm" if not prefix else f"{prefix}.model.norm", prefix=f"{prefix}.norm",
weights=weights, weights=weights,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
@ -629,19 +615,24 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights): def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__() super().__init__()
with no_fp8(weights): with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=( prefix=(
"model.embed_tokens" f"{name}.embed_tokens"
if not prefix if not prefix
else f"{prefix}.model.embed_tokens" else f"{prefix}.{name}.embed_tokens"
), ),
weights=weights, weights=weights,
) )
self.model = FlashLlamaModel(prefix, config, weights) self.model = FlashLlamaModel(
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
suffix = "model.embed_tokens" suffix = "model.embed_tokens"
else: else:
@ -652,11 +643,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if embedding_multiplier is not None: if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier self.embed_tokens.weight.data *= embedding_multiplier
prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights): with no_fp8(weights):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix=suffix if not prefix else f"{prefix}.{suffix}", prefix,
weights=weights, weights,
) )
# Used in Granite # Used in Granite

View File

@ -49,9 +49,9 @@ from text_generation_server.layers.layernorm import (
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
from vllm import _custom_C import vllm._custom_ops as ops
except Exception as e: except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
class MistralConfig(PretrainedConfig): class MistralConfig(PretrainedConfig):
@ -318,7 +318,7 @@ class MistralMLP(nn.Module):
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device="cuda", device="cuda",
) )
_custom_C.LLMM_Silu( ops.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
) )
return self.down_proj(out, adapter_data) return self.down_proj(out, adapter_data)

View File

@ -0,0 +1,584 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics3 model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Idefics3VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the need to resize them to the same
fixed size. In particular, we start from the original pre-trained SigLIP model
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
def forward(
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(
fractional_coords_h, boundaries, right=True
)
bucket_coords_w = torch.bucketize(
fractional_coords_w, boundaries, right=True
)
pos_ids = (
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics3VisionAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
self.head_size * self.num_heads,
self.head_size * self.num_heads,
],
dim=2,
)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Idefics3VisionMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Idefics3EncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics3VisionAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
)
self.mlp = Idefics3VisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics3Encoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
Idefics3EncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
return hidden_states
class Idefics3VisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embeddings = Idefics3VisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = Idefics3Encoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
):
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_size = self.config.patch_size
patch_attention_mask = torch.ones(
(
batch_size,
pixel_values.size(2) // patch_size,
pixel_values.size(3) // patch_size,
)
)
patch_attention_mask = patch_attention_mask.to(
dtype=torch.bool, device=pixel_values.device
)
hidden_states = self.embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
else:
patch_attention_mask = _prepare_4d_attention_mask(
patch_attention_mask, hidden_states.dtype
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=patch_attention_mask,
)
last_hidden_state = encoder_outputs
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
class Idefics3SimpleMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
input_size = config.vision_config.hidden_size * (config.scale_factor**2)
output_size = config.text_config.hidden_size
proj = nn.Parameter(
weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
requires_grad=False,
).to(weights.dtype)
self.proj = nn.Linear(input_size, output_size, bias=False)
self.proj.weight = proj
def forward(self, x):
return self.proj(x)
class Idefics3Connector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
self.scale_factor = config.scale_factor
def pixel_shuffle(self, x, scale_factor=2):
bsz, seq, embed_dim = x.size()
height = width = int(seq**0.5)
x = x.view(bsz, height, width, embed_dim)
x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
x = x.permute(0, 2, 1, 3)
x = x.reshape(
bsz,
int(width / scale_factor),
int(height / scale_factor),
embed_dim * (scale_factor**2),
)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
return x
def forward(self, image_hidden_states):
image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
image_hidden_states = self.modality_projection(image_hidden_states)
return image_hidden_states
class Idefics3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
# set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
# since Idefics3 uses the `embed_tokens` for the final prediction
# config.text_config.tie_word_embeddings = True
vision_config = config.vision_config
self.text_model = load_text_model(
prefix="model" if not prefix else f"{prefix}.model",
config=config.text_config,
weights=weights,
name="text_model",
)
self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
self.vision_model = Idefics3VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
),
config=vision_config,
weights=weights,
)
config.quantize = None
self.connector = Idefics3Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_token_id = config.image_token_id
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -52,7 +52,7 @@ from loguru import logger
if SYSTEM == "cuda": if SYSTEM == "cuda":
import dropout_layer_norm import dropout_layer_norm
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops import vllm._custom_ops as ops
else: else:
dropout_layer_norm = None dropout_layer_norm = None

View File

@ -450,7 +450,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
width //= self.spatial_merge_size width //= self.spatial_merge_size
# calculate the length of the text and image tokens # calculate the length of the text and image tokens
text_length = next_image_pos - current_pos text_length = next_image_pos
start_idx = ( start_idx = (
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
) )
@ -480,7 +480,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
) )
llm_pos_ids_list.append(image_pos_ids) llm_pos_ids_list.append(image_pos_ids)
current_pos = next_image_pos + time_steps * height * width current_pos += next_image_pos + time_steps * height * width
image_index += 1 image_index += 1
if current_pos < batch_input_ids.size(1): if current_pos < batch_input_ids.size(1):

View File

@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
) )
return FlashLlamaForCausalLM(prefix, config, weights) return FlashLlamaForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "mistral": elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM, FlashMistralForCausalLM,

View File

@ -1288,7 +1288,7 @@ class FlashCausalLM(Model):
weights_loader=weights_loader, weights_loader=weights_loader,
) )
prefix = "" prefix = None
model = model_class(prefix, config, weights) model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -1663,7 +1663,7 @@ class FlashCausalLM(Model):
for seqlen in tuning_sequences: for seqlen in tuning_sequences:
log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen) self.tunableop_warmup(seqlen, max_total_tokens)
torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.write_file(tunableop_filepath)
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.tuning_enable(False)
@ -1710,7 +1710,7 @@ class FlashCausalLM(Model):
assert max_total_tokens is not None assert max_total_tokens is not None
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def tunableop_warmup(self, seqlen: int): def tunableop_warmup(self, seqlen: int, max_bt: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
@ -1724,11 +1724,15 @@ class FlashCausalLM(Model):
[0, seqlen], device=self.device, dtype=torch.int32 [0, seqlen], device=self.device, dtype=torch.int32
) )
max_s = seqlen max_s = seqlen
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(seqlen)
block_tables = block_tables.reshape((seqlen, max_bt))
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen, max_k=seqlen,
) )
@ -1738,7 +1742,7 @@ class FlashCausalLM(Model):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=None, block_tables=block_tables,
seqlen=seqlen, seqlen=seqlen,
slots=slots, slots=slots,
max_s=max_s, max_s=max_s,
@ -2480,7 +2484,8 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
dtype=self.dtype, kv_dtype=self.kv_cache_dtype,
q_dtype=self.dtype,
window_left=self.sliding_window, window_left=self.sliding_window,
) )
else: else:
@ -2494,6 +2499,6 @@ class FlashCausalLM(Model):
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
dtype=self.dtype, q_dtype=self.dtype,
window_left=self.sliding_window, window_left=self.sliding_window,
) )

View File

@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLM, FlashCausalLM,
) )
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from loguru import logger
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from transformers import AutoProcessor from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -23,6 +24,40 @@ tracer = trace.get_tracer(__name__)
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>" IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>" IDEFICS2_IMAGE_TOKEN = "<image>"
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
*,
image_seq_len: int,
image_rows: int,
image_cols: int,
fake_token_around_image: str,
image_token: str,
global_img_token: str,
):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}"
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
+ f"{image_token}" * image_seq_len
)
text_split_images += "\n"
text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
""" """
@ -54,10 +89,26 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
if processor.image_processor.do_image_splitting: if processor.image_processor.do_image_splitting:
image_str *= 5 image_str *= 5
return image_str return image_str
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
)
image_str = _prompt_split_image(
image_seq_len=image_seq_len,
image_rows=n_rows,
image_cols=n_cols,
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
elif config.model_type == "llava_next": elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id] height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config) num_features = get_number_of_features(height, width, config)
from loguru import logger
log_master( log_master(
logger.info, logger.info,
@ -68,7 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "paligemma": elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens return "<image>" * config.text_config.num_image_tokens
elif config.model_type == "qwen2_vl": elif config.model_type == "qwen2_vl":
num_pads = image_input.pixel_values.shape[0] // 4 grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>" return f"<|vision_start|>{padding}<|vision_end|>"
else: else:
@ -193,12 +245,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images: if images:
image_inputs = processor.image_processor(images, return_tensors="pt") kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else: else:
image_inputs = None image_inputs = None
batch_inputs = [] batch_tokenized_inputs = []
max_truncation = 0 max_length = 0
image_id = 0 image_id = 0
for r in requests: for r in requests:
full_text = "" full_text = ""
@ -213,16 +274,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image_id += 1 image_id += 1
full_text = image_text_replacement_fixup(config, full_text) full_text = image_text_replacement_fixup(config, full_text)
input_ids = tokenizer(
batch_inputs.append(full_text) full_text,
max_truncation = max(max_truncation, r.truncate) truncation=True,
max_length=r.truncate,
batch_tokenized_inputs = tokenizer( add_special_tokens=r.add_special_tokens,
batch_inputs, )["input_ids"]
truncation=True, max_length = max(max_length, len(input_ids))
max_length=max_truncation, batch_tokenized_inputs.append(input_ids)
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs

View File

@ -1,6 +1,6 @@
import os import os
import torch import torch
from torch.distributed import ProcessGroup
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -18,10 +18,11 @@ class FakeBarrier:
pass pass
class FakeGroup: class FakeGroup(ProcessGroup):
def __init__(self, rank, size): def __init__(self, rank, size):
self._rank = rank self._rank = rank
self._size = size self._size = size
super().__init__(rank, size)
def allreduce(self, *args, **kwargs): def allreduce(self, *args, **kwargs):
return FakeBarrier() return FakeBarrier()