diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c0199a66..c43d8eb9 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -67,7 +67,8 @@ jobs: export label_extension="-rocm" export docker_devices="/dev/kfd,/dev/dri" export docker_volume="/mnt" - export runs_on="amd-gpu-runners" + # This runner was deactivated. + export runs_on="ubuntu-latest" export platform="" export extra_pytest="-k test_flash_gemma_gptq_load" ;; diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4eeca334..6bcf7d96 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -31,7 +31,7 @@ jobs: with: # Released on: 02 May, 2024 # https://releases.rs/docs/1.78.0/ - toolchain: 1.80.0 + toolchain: 1.84.0 override: true components: rustfmt, clippy - name: Install Protoc diff --git a/Cargo.lock b/Cargo.lock index 74ae6e16..e63d1540 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -456,18 +456,18 @@ dependencies = [ [[package]] name = "bit-set" -version = "0.5.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ "bit-vec", ] [[package]] name = "bit-vec" -version = "0.6.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bit_field" @@ -502,6 +502,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "borrow-or-share" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32" + [[package]] name = "built" version = "0.7.5" @@ -1139,6 +1145,15 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "encode_unicode" version = "0.3.6" @@ -1196,12 +1211,13 @@ dependencies = [ [[package]] name = "fancy-regex" -version = "0.11.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" dependencies = [ "bit-set", - "regex", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -1247,6 +1263,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" +[[package]] +name = "fluent-uri" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5" +dependencies = [ + "borrow-or-share", + "ref-cast", + "serde", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1285,9 +1312,9 @@ dependencies = [ [[package]] name = "fraction" -version = "0.13.1" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678" +checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7" dependencies = [ "lazy_static", "num", @@ -1414,10 +1441,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi", - "wasm-bindgen", ] [[package]] @@ -1573,7 +1598,7 @@ dependencies = [ "native-tls", "num_cpus", "rand", - "reqwest", + "reqwest 0.11.27", "serde", "serde_json", "thiserror", @@ -2051,15 +2076,6 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" -[[package]] -name = "iso8601" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153" -dependencies = [ - "nom", -] - [[package]] name = "itertools" version = "0.10.5" @@ -2128,32 +2144,27 @@ dependencies = [ [[package]] name = "jsonschema" -version = "0.17.1" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978" +checksum = "74d8eb539cdb4222da29bb658cc9881aa2477b33fb1a74c5c31450395fc1a4b2" dependencies = [ "ahash", - "anyhow", - "base64 0.21.7", + "base64 0.22.1", "bytecount", - "clap 4.5.21", + "email_address", "fancy-regex", "fraction", - "getrandom", - "iso8601", + "idna", "itoa", - "memchr", "num-cmp", "once_cell", - "parking_lot", "percent-encoding", - "regex", - "reqwest", + "referencing", + "regex-syntax 0.8.5", + "reqwest 0.12.9", "serde", "serde_json", - "time", - "url", - "uuid", + "uuid-simd", ] [[package]] @@ -2984,6 +2995,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "outref" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" + [[package]] name = "overload" version = "0.1.1" @@ -3557,6 +3574,39 @@ dependencies = [ "thiserror", ] +[[package]] +name = "ref-cast" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "referencing" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093a875008827c0ae15c746189966e162faa05bf347719d06302c548ac63630f" +dependencies = [ + "ahash", + "fluent-uri", + "once_cell", + "percent-encoding", + "serde_json", +] + [[package]] name = "regex" version = "1.11.1" @@ -3641,6 +3691,42 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.1", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + 
"pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows-registry", +] + [[package]] name = "rgb" version = "0.8.50" @@ -4220,6 +4306,9 @@ name = "sync_wrapper" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] [[package]] name = "synstructure" @@ -4404,7 +4493,7 @@ dependencies = [ "once_cell", "pyo3", "regex", - "reqwest", + "reqwest 0.11.27", "serde", "serde_json", "thiserror", @@ -4445,7 +4534,7 @@ dependencies = [ "pyo3", "rand", "regex", - "reqwest", + "reqwest 0.11.27", "serde", "serde_json", "sysinfo", @@ -4493,7 +4582,7 @@ dependencies = [ "prost-build", "rand", "regex", - "reqwest", + "reqwest 0.11.27", "serde", "serde_json", "slotmap", @@ -4544,7 +4633,7 @@ dependencies = [ "prost-build", "rand", "regex", - "reqwest", + "reqwest 0.11.27", "serde", "serde_json", "slotmap", @@ -5298,6 +5387,17 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "uuid-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8" +dependencies = [ + "outref", + "uuid", + "vsimd", +] + [[package]] name = "v_frame" version = "0.3.8" @@ -5349,6 +5449,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "walkdir" version = "2.5.0" @@ -5558,6 +5664,36 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.45.0" diff --git a/Dockerfile b/Dockerfile index 0c08d48f..0f2ae6cc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/Dockerfile_amd b/Dockerfile_amd index dc748f49..92acff5a 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -268,6 +268,15 @@ COPY server/exllamav2_kernels/ . 
RUN python setup.py build +FROM kernel-builder AS marlin-kernels +WORKDIR /usr/src +ENV MARLIN_KERNELS_BRANCH=v0.3.6 +ENV VLLM_TARGET_DEVICE=rocm +RUN git clone https://github.com/danieldk/marlin-kernels.git && \ + cd marlin-kernels && \ + git checkout ${MARLIN_KERNELS_BRANCH} && \ + python setup.py install + FROM kernel-builder AS moe-kernels WORKDIR /usr/src ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd @@ -299,6 +308,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 # Copy build artifacts from exllamav2 kernels builder COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages +# Copy build artifacts from marlin kernels +COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages + # Copy build artifacts from moe kernels COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages diff --git a/Dockerfile_intel b/Dockerfile_intel index e024f31a..2b41fd8b 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,6 +1,6 @@ ARG PLATFORM=xpu -FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -97,11 +97,10 @@ ENV HF_HOME=/data \ WORKDIR /usr/src -RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir -RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir -RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir -RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir -RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir +RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir +RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir +RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir +RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install triton-xpu==3.0.0b2 --no-cache-dir @@ -119,6 +118,9 @@ ENV CCL_ZE_IPC_EXCHANGE=sockets #ENV TORCH_LLM_ALLREDUCE=1 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6 +RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch + # Install benchmarker COPY 
--from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/Dockerfile_trtllm b/Dockerfile_trtllm index b4523ea5..e6a16ecc 100644 --- a/Dockerfile_trtllm +++ b/Dockerfile_trtllm @@ -2,7 +2,7 @@ ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" ARG OMPI_VERSION="4.1.7rc1" # Build dependencies resolver stage -FROM lukemathwalker/cargo-chef:latest AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef WORKDIR /usr/src/text-generation-inference/backends/trtllm FROM chef AS planner diff --git a/README.md b/README.md index 6d3a9b12..9842a2a7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
- Making TGI deployment optimal + Making TGI deployment optimal # Text Generation Inference @@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta volume=$PWD/data docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ -3.0.0 ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model ``` And then you can make requests like @@ -141,8 +141,8 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri For example, if you want to serve the gated Llama V2 model variants: 1. Go to https://huggingface.co/settings/tokens -2. Copy your cli READ token -3. Export `HF_TOKEN=` +2. Copy your CLI READ token +3. Export `HF_TOKEN=` or with Docker: @@ -151,13 +151,14 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= -docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model +docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model ``` ### A note on Shared Memory (shm) [`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by -`PyTorch` to do distributed training/inference. `text-generation-inference` make +`PyTorch` to do distributed training/inference. `text-generation-inference` makes use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models. In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if @@ -196,7 +197,7 @@ Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with T You can also opt to install `text-generation-inference` locally. -First clone the repository and change directoy into it: +First clone the repository and change directory into it: ```shell git clone https://github.com/huggingface/text-generation-inference @@ -213,7 +214,7 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh conda create -n text-generation-inference python=3.11 conda activate text-generation-inference -#using pyton venv +#using python venv python3 -m venv .venv source .venv/bin/activate ``` diff --git a/backends/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs index 3068a61c..822b0307 100644 --- a/backends/grpc-metadata/src/lib.rs +++ b/backends/grpc-metadata/src/lib.rs @@ -8,7 +8,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// Inject context in the metadata of a gRPC request. struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); -impl<'a> Injector for MetadataInjector<'a> { +impl Injector for MetadataInjector<'_> { /// Set a key and value in the MetadataMap. 
Does nothing if the key or value are not valid inputs fn set(&mut self, key: &str, value: String) { if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml index 4d32474e..0decf41a 100644 --- a/backends/v2/Cargo.toml +++ b/backends/v2/Cargo.toml @@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] } grpc-metadata = { path = "../grpc-metadata" } futures = "0.3.28" hf-hub = { workspace = true } -jsonschema = { version = "0.17.1", features = ["draft202012"] } +jsonschema = { version = "0.28.0" } metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs index 61a3eebc..c9a9335d 100644 --- a/backends/v2/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -213,8 +213,7 @@ impl State { } // Pad prefill_token_budget to be a multiple of block size - let prefill_token_budget = - ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; + let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); @@ -245,9 +244,8 @@ impl State { prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { // pad to block size - prefill_tokens += ((entry.request.input_length + self.block_size - 1) - / self.block_size) - * self.block_size; + prefill_tokens += + entry.request.input_length.div_ceil(self.block_size) * self.block_size; } if self.requires_padding { @@ -262,8 +260,7 @@ impl State { }; // pad to block size - decode_tokens += - ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; + decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size; } if prefill_tokens > prefill_token_budget diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 69dad072..996290ed 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] } grpc-metadata = { path = "../grpc-metadata" } futures = "0.3.28" hf-hub = { workspace = true } -jsonschema = { version = "0.17.1", features = ["draft202012"] } +jsonschema = { version = "0.28.0" } metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 4fea172b..e7f3d85a 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -165,13 +165,13 @@ impl Allocator for SimpleAllocator { let (tokens, repeats) = match self.window_size { None => (tokens, 1), Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; + let repeats = tokens.div_ceil(window_size); let tokens = core::cmp::min(tokens, window_size); (tokens, repeats as usize) } }; // Pad to a multiple of block size - let required_blocks = (tokens + self.block_size - 1) / self.block_size; + let required_blocks = tokens.div_ceil(self.block_size); (required_blocks, repeats) }; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index dd27806f..249eebf7 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -257,8 +257,7 @@ impl State { } // Pad prefill_token_budget to be a multiple of block size - let prefill_token_budget = - ((prefill_token_budget + 
self.block_size - 1) / self.block_size) * self.block_size; + let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 8a544891..532ec6dd 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -103,7 +103,7 @@ impl Allocator for RadixAllocator { let prefix_len = blocks.len() * self.block_size as usize; let suffix_len = tokens - prefix_len as u32; - let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + let suffix_blocks = suffix_len.div_ceil(self.block_size); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 4503424b..8fcba516 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -13,6 +13,8 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_tpu + title: Using TGI with Google TPUs - local: installation_intel title: Using TGI with Intel GPUs - local: installation diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md index 2d55c952..e389fbbc 100644 --- a/docs/source/basic_tutorials/using_guidance.md +++ b/docs/source/basic_tutorials/using_guidance.md @@ -187,8 +187,6 @@ In addition to the grammar parameter, we've also introduced a set of tools and f Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. -Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. - ```json curl localhost:3000/v1/chat/completions \ -X POST \ diff --git a/docs/source/conceptual/speculation.md b/docs/source/conceptual/speculation.md index 45618ae3..74e010c8 100644 --- a/docs/source/conceptual/speculation.md +++ b/docs/source/conceptual/speculation.md @@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models: - [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa) -In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md) +In order to create your own medusa heads for your own finetune, you should check out the original medusa repo. Read more in [Train Medusa](../basic_tutorials/train_medusa#training). In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically. diff --git a/docs/source/installation_tpu.md b/docs/source/installation_tpu.md new file mode 100644 index 00000000..559e83aa --- /dev/null +++ b/docs/source/installation_tpu.md @@ -0,0 +1,3 @@ +# Using TGI with Google TPUs + +Check out this [guide](https://huggingface.co/docs/optimum-tpu) on how to serve models with TGI on TPUs. diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 0f39ff28..5ac90351 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. 
The following sectio - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) +- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) diff --git a/flake.lock b/flake.lock index ec87d569..23e76b8b 100644 --- a/flake.lock +++ b/flake.lock @@ -108,11 +108,11 @@ "pre-commit-hooks": "pre-commit-hooks_3" }, "locked": { - "lastModified": 1732039290, - "narHash": "sha256-LQKY7bShf2H9kJouxa9ZspfdrulnZF9o4kLTqGqCDYM=", + "lastModified": 1734429562, + "narHash": "sha256-V2XNs3Ir8WXNHdocfzkR/fu0FzkZ9uTDJkVecxJrGmQ=", "owner": "nix-community", "repo": "crate2nix", - "rev": "9ff208ce7f5a482272b1bcefbe363c772d7ff914", + "rev": "8537c2d7cb623679aaeff62c4c4c43a91566ab09", "type": "github" }, "original": { @@ -853,11 +853,11 @@ ] }, "locked": { - "lastModified": 1732242723, - "narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=", + "lastModified": 1736907983, + "narHash": "sha256-fw55wVwpJW36Md2HZBKuxX3YHGeqsGsspPLtCMVr1Y8=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a", + "rev": "eaa365c911441e07e387ff6acc596619fc50b156", "type": "github" }, "original": { @@ -978,11 +978,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1732218602, - "narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=", + "lastModified": 1736436388, + "narHash": "sha256-CIyxVPpM9RrSwthNT/4DQ10YPk/uwzP7AeE83kBNsrE=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "f79638ac4e420e661321261744e745a3a747e182", + "rev": "5103c3fb1f9ad1fd33b6e09ff05e957884b112d5", "type": "github" }, "original": { diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index c9c47766..c702ae70 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -354,6 +354,7 @@ def launcher(event_loop): kv_cache_dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, + max_input_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, lora_adapters: Optional[List[str]] = None, @@ -402,6 +403,9 @@ def launcher(event_loop): if max_input_length: args.append("--max-input-length") args.append(str(max_input_length)) + if max_input_tokens: + args.append("--max-input-tokens") + args.append(str(max_input_tokens)) if max_batch_prefill_tokens: args.append("--max-batch-prefill-tokens") args.append(str(max_batch_prefill_tokens)) diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json index 7d35e8f9..771708eb 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json @@ -32,7 +32,7 @@ }, { "id": 1101, - "logprob": -1.0947266, + "logprob": -1.0136719, "special": false, "text": " also" }, @@ -56,13 +56,13 @@ }, { "id": 4009, - "logprob": -0.15563965, + "logprob": -0.21923828, 
"special": false, "text": " network" }, { "id": 477, - "logprob": -1.4003906, + "logprob": -1.4824219, "special": false, "text": " or" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json index 0db48f3e..6b3f5092 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json @@ -8,7 +8,7 @@ "tokens": [ { "id": 1939, - "logprob": -2.2675781, + "logprob": -2.2460938, "special": false, "text": "?\n\n" }, @@ -20,13 +20,13 @@ }, { "id": 20909, - "logprob": -0.37695312, + "logprob": -0.48608398, "special": false, "text": " Learning" }, { "id": 4102, - "logprob": -1.9316406, + "logprob": -2.265625, "special": false, "text": " " }, @@ -38,25 +38,13 @@ }, { "id": 458, - "logprob": -0.80859375, + "logprob": -0.6328125, "special": false, "text": " an" }, - { - "id": 3082, - "logprob": -1.4541016, - "special": false, - "text": " area" - }, - { - "id": 315, - "logprob": 0.0, - "special": false, - "text": " of" - }, { "id": 20443, - "logprob": -0.5136719, + "logprob": -0.1796875, "special": false, "text": " artificial" }, @@ -65,9 +53,21 @@ "logprob": 0.0, "special": false, "text": " intelligence" + }, + { + "id": 320, + "logprob": -0.37695312, + "special": false, + "text": " (" + }, + { + "id": 15469, + "logprob": 0.0, + "special": false, + "text": "AI" } ], "top_tokens": null }, - "generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" + "generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json index abcaf876..1fa4e33a 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json @@ -9,61 +9,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + 
"logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } @@ -82,61 +82,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } @@ -155,61 +155,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } @@ -228,61 +228,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json 
b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json index 08c63e79..29709676 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json @@ -44,7 +44,7 @@ }, { "id": 38397, - "logprob": -0.12695312, + "logprob": 0.0, "special": false, "text": " subset" }, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json index 6306f75e..0f54bbe8 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -14,60 +14,60 @@ }, { "id": 573, - "logprob": -0.18493652, + "logprob": -0.19030762, "special": false, "text": " the" }, { "id": 16819, - "logprob": -1.4804688, + "logprob": -1.4863281, "special": false, "text": " detection" }, { "id": 576, - "logprob": -0.7011719, + "logprob": -0.7089844, + "special": false, + "text": " of" + }, + { + "id": 573, + "logprob": -2.0410156, + "special": false, + "text": " the" + }, + { + "id": 8566, + "logprob": 0.0, + "special": false, + "text": " presence" + }, + { + "id": 689, + "logprob": -0.16491699, + "special": false, + "text": " or" + }, + { + "id": 14862, + "logprob": 0.0, + "special": false, + "text": " absence" + }, + { + "id": 576, + "logprob": -0.9970703, "special": false, "text": " of" }, { "id": 671, - "logprob": -2.1738281, + "logprob": -0.5292969, "special": false, "text": " an" - }, - { - "id": 24646, - "logprob": -3.0449219, - "special": false, - "text": " RNA" - }, - { - "id": 12369, - "logprob": -0.19299316, - "special": false, - "text": " virus" - }, - { - "id": 575, - "logprob": -0.10632324, - "special": false, - "text": " in" - }, - { - "id": 6022, - "logprob": -0.98095703, - "special": false, - "text": " patients" - }, - { - "id": 1064, - "logprob": -1.3095703, - "special": false, - "text": " who" } ], "top_tokens": null }, - "generated_text": "Test request for the detection of an RNA virus in patients who" + "generated_text": "Test request for the detection of the presence or absence of an" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json index 914e59c0..6674cf50 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -8,7 +8,7 @@ "tokens": [ { "id": 2284, - "logprob": -0.296875, + "logprob": -0.31323242, "special": false, "text": "():" }, @@ -38,13 +38,13 @@ }, { "id": 10914, - "logprob": -0.7734375, + "logprob": -0.7871094, "special": false, "text": " World" }, { "id": 16013, - "logprob": -0.61816406, + "logprob": -0.64746094, "special": false, "text": "!\")" }, @@ -62,7 +62,7 @@ }, { "id": 610, - "logprob": -0.4152832, + "logprob": -0.41064453, "special": false, "text": "def" }, @@ -92,7 +92,7 @@ }, { "id": 444, - "logprob": -0.21618652, + "logprob": -0.21655273, "special": false, "text": "name" }, @@ -139,28 +139,16 @@ "text": 
"Hello" }, { - "id": 925, - "logprob": -3.3476562, + "id": 332, + "logprob": -0.034698486, "special": false, - "text": " %" + "text": " \"" }, { - "id": 120, + "id": 494, "logprob": 0.0, "special": false, - "text": "s" - }, - { - "id": 11571, - "logprob": -0.08892822, - "special": false, - "text": "!\"" - }, - { - "id": 925, - "logprob": 0.0, - "special": false, - "text": " %" + "text": " +" }, { "id": 655, @@ -169,10 +157,22 @@ "text": " name" }, { - "id": 46, + "id": 494, + "logprob": -0.20141602, + "special": false, + "text": " +" + }, + { + "id": 332, "logprob": 0.0, "special": false, - "text": ")" + "text": " \"" + }, + { + "id": 16013, + "logprob": 0.0, + "special": false, + "text": "!\")" }, { "id": 222, @@ -230,7 +230,7 @@ }, { "id": 400, - "logprob": -0.074279785, + "logprob": 0.0, "special": false, "text": "age" }, @@ -289,22 +289,34 @@ "text": "Hello" }, { - "id": 925, + "id": 332, "logprob": 0.0, "special": false, - "text": " %" + "text": " \"" }, { - "id": 120, + "id": 494, "logprob": 0.0, "special": false, - "text": "s" + "text": " +" }, { - "id": 49, - "logprob": -0.07891846, + "id": 655, + "logprob": 0.0, "special": false, - "text": "," + "text": " name" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 3021, + "logprob": -0.5761719, + "special": false, + "text": " \"," }, { "id": 863, @@ -319,55 +331,43 @@ "text": " are" }, { - "id": 925, + "id": 332, "logprob": 0.0, "special": false, - "text": " %" + "text": " \"" }, { - "id": 105, + "id": 494, "logprob": 0.0, "special": false, - "text": "d" + "text": " +" }, { - "id": 11339, + "id": 615, "logprob": 0.0, "special": false, - "text": " years" + "text": " str" }, { - "id": 3627, + "id": 45, "logprob": 0.0, "special": false, - "text": " old" + "text": "(" }, { - "id": 11571, + "id": 400, "logprob": 0.0, "special": false, - "text": "!\"" + "text": "age" }, { - "id": 925, + "id": 46, "logprob": 0.0, "special": false, - "text": " %" - }, - { - "id": 327, - "logprob": 0.0, - "special": false, - "text": " (" - }, - { - "id": 444, - "logprob": 0.0, - "special": false, - "text": "name" + "text": ")" } ], "top_tokens": null }, - "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name" + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json new file mode 100644 index 00000000..1bc1e0fd --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.9355469, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40795898, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + 
"logprob": -1.4599609, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.625, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23242188, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2294922, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json new file mode 100644 index 00000000..ce3831b0 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json @@ -0,0 +1,373 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [], + "seed": 0, + "tokens": [ + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -0.7944336, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -0.1796875, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": 0.0, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": 0.0, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": 0.0, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": 0.0, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": 0.0, + "special": false, + "text": "slide" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 700, + "logprob": 0.0, + "special": false, + "text": "type" + }, + { + "id": 582, + "logprob": 0.0, + "special": false, + "text": "\":" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 7277, + "logprob": -0.06994629, + "special": false, + "text": "slide" + }, + { + "id": 3667, + "logprob": 0.0, + "special": false, + "text": "\"}" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": 0.0, + "special": false, + "text": "#" + }, + { + "id": 607, + "logprob": -0.8261719, + "special": false, + "text": " #" + }, + { + "id": 244, + "logprob": -1.8574219, + "special": false, + "text": " " + }, + { + "id": 55, + "logprob": -1.4541016, + "special": false, + "text": "2" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." 
+ }, + { + "id": 6208, + "logprob": -0.9794922, + "special": false, + "text": " What" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 341, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 10609, + "logprob": -0.69189453, + "special": false, + "text": " difference" + }, + { + "id": 3761, + "logprob": 0.0, + "special": false, + "text": " between" + }, + { + "id": 331, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 1168, + "logprob": -0.27172852, + "special": false, + "text": " list" + }, + { + "id": 480, + "logprob": 0.0, + "special": false, + "text": " and" + }, + { + "id": 331, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 8871, + "logprob": 0.0, + "special": false, + "text": " tuple" + }, + { + "id": 68, + "logprob": 0.0, + "special": false, + "text": "?" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -1.3359375, + "special": false, + "text": "#" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": 0.0, + "special": false, + "text": "#" + }, + { + "id": 449, + "logprob": -0.03164673, + "special": false, + "text": " -" + }, + { + "id": 418, + "logprob": -1.0947266, + "special": false, + "text": " A" + }, + { + "id": 1168, + "logprob": 0.0, + "special": false, + "text": " list" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 331, + "logprob": -0.3305664, + "special": false, + "text": " a" + }, + { + "id": 14792, + "logprob": 0.0, + "special": false, + "text": " mutable" + }, + { + "id": 6645, + "logprob": -0.40478516, + "special": false, + "text": " sequence" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 4725, + "logprob": -0.50390625, + "special": false, + "text": " elements" + }, + { + "id": 49, + "logprob": -2.1269531, + "special": false, + "text": "," + }, + { + "id": 2236, + "logprob": -0.1427002, + "special": false, + "text": " while" + }, + { + "id": 331, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 8871, + "logprob": 0.0, + "special": false, + "text": " tuple" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 619, + "logprob": 0.0, + "special": false, + "text": " an" + }, + { + "id": 26079, + "logprob": 0.0, + "special": false, + "text": " immutable" + }, + { + "id": 6645, + "logprob": 0.0, + "special": false, + "text": " sequence" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 4725, + "logprob": 0.0, + "special": false, + "text": " elements" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": 0.0, + "special": false, + "text": "#" + }, + { + "id": 449, + "logprob": 0.0, + "special": false, + "text": " -" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. 
What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json new file mode 100644 index 00000000..bf9e3010 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json @@ -0,0 +1,294 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + 
"text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json new file mode 100644 index 00000000..de76dd50 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json @@ -0,0 +1,71 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 100, + "logprob": -0.9824219, + "special": false, + "text": "_" + }, + { + "id": 5879, + "logprob": -0.3017578, + "special": false, + "text": "world" + }, + { + "id": 2284, + "logprob": -0.68652344, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.27734375, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.4482422, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.4296875, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.8544922, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7573242, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.81347656, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": "_world():\n print(\"Hello World!\")\n" +} diff --git a/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json new file mode 100644 index 00000000..6bf2b93a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json @@ -0,0 +1,67 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 9, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 2684, + "logprob": -0.24902344, + "special": false, + "text": " There" + }, + { + "id": 374, + "logprob": -0.0703125, + 
"special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.23535156, + "special": false, + "text": " a" + }, + { + "id": 35372, + "logprob": -0.125, + "special": false, + "text": " statue" + }, + { + "id": 304, + "logprob": -0.30273438, + "special": false, + "text": " in" + }, + { + "id": 279, + "logprob": -0.20507812, + "special": false, + "text": " the" + }, + { + "id": 2217, + "logprob": -0.076171875, + "special": false, + "text": " image" + }, + { + "id": 13, + "logprob": -0.053710938, + "special": false, + "text": "." + }, + { + "id": 128258, + "logprob": -0.011352539, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " There is a statue in the image." +} diff --git a/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json new file mode 100644 index 00000000..17a69d0d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json @@ -0,0 +1,61 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 8, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.118652344, + "special": false, + "text": " A" + }, + { + "id": 11426, + "logprob": -0.28320312, + "special": false, + "text": " bee" + }, + { + "id": 335, + "logprob": -0.95703125, + "special": false, + "text": " on" + }, + { + "id": 253, + "logprob": -0.06982422, + "special": false, + "text": " a" + }, + { + "id": 11986, + "logprob": -0.49414062, + "special": false, + "text": " pink" + }, + { + "id": 8525, + "logprob": -0.07763672, + "special": false, + "text": " flower" + }, + { + "id": 30, + "logprob": -1.0703125, + "special": false, + "text": "." + }, + { + "id": 49154, + "logprob": -0.092285156, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " A bee on a pink flower." 
+} diff --git a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py index 7cc82a4e..a0b0416b 100644 --- a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py +++ b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py @@ -64,7 +64,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params( assert response.details.generated_tokens == 10 assert ( response.generated_text - == "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" + == "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" ) assert response == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder2_lora.py b/integration-tests/models/test_flash_starcoder2_lora.py new file mode 100644 index 00000000..6480f669 --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2_lora.py @@ -0,0 +1,79 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher( + "bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"] + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return flash_starcoder2_handle.client + + +@pytest.mark.asyncio +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "who are you?", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "who are you?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_with_hugcode_adapter( + flash_starcoder2, response_snapshot +): + response = requests.post( + f"{flash_starcoder2.base_url}/generate", + headers=flash_starcoder2.headers, + json={ + "inputs": "def print_hello", + "parameters": { + "max_new_tokens": 10, + "adapter_id": "smangrul/starcoder-3b-hugcoder", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generated_text"] == '_world():\n print("Hello World!")\n' + + assert data == response_snapshot diff --git a/integration-tests/models/test_idefics3.py b/integration-tests/models/test_idefics3.py new file mode 100644 index 00000000..80be2350 --- /dev/null +++ b/integration-tests/models/test_idefics3.py @@ -0,0 +1,31 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_idefics3_next_handle(launcher): + with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_idefics3_next(flash_idefics3_next_handle): + await 
flash_idefics3_next_handle.health(300) + return flash_idefics3_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot): + ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + query = "What is in this image?" + response = await flash_idefics3_next.generate( + f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}\nAssistant:", + max_new_tokens=10, + seed=1337, + ) + print(response) + assert ( + response.generated_text == " There is a statue in the image." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 9 + assert response == response_snapshot diff --git a/integration-tests/models/test_smolvlm.py b/integration-tests/models/test_smolvlm.py new file mode 100644 index 00000000..cd105d84 --- /dev/null +++ b/integration-tests/models/test_smolvlm.py @@ -0,0 +1,31 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_smolvlm_next_handle(launcher): + with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_smolvlm_next(flash_smolvlm_next_handle): + await flash_smolvlm_next_handle.health(300) + return flash_smolvlm_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot): + ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg" + query = "What is in this image?" + response = await flash_smolvlm_next.generate( + f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}\nAssistant:", + max_new_tokens=10, + seed=1337, + ) + print(response) + assert ( + response.generated_text == " A bee on a pink flower." 
+ ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 8 + assert response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fb6ba2b2..18badeaf 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -5,7 +5,6 @@ use hf_hub::{ }; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; -use regex::Regex; use serde::Deserialize; use std::env; use std::ffi::OsString; @@ -1652,7 +1651,11 @@ impl From<&str> for Gpu { "nvidia-l40s" => Gpu::L40S, "nvidia-a10g" => Gpu::A10G, "nvidia-h100-80gb-hbm3" => Gpu::H100, + "nvidia-h100-nvl" => Gpu::H100, + "nvidia-h100" => Gpu::H100, "nvidia-a100-sxm4-80gb" => Gpu::A100, + "nvidia-a100-sxm4-40gb" => Gpu::A100, + "nvidia-a100-80gb-pcie" => Gpu::A100, "nvidia-a100" => Gpu::A100, card => Gpu::Unknown(card.to_string()), } @@ -2075,14 +2078,7 @@ fn main() -> Result<(), LauncherError> { let cuda_graphs = match (&args.cuda_graphs, &quantize) { (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), #[allow(deprecated)] - ( - None, - Some( - Quantization::Bitsandbytes - | Quantization::BitsandbytesNf4 - | Quantization::BitsandbytesFp4, - ), - ) => { + (None, Some(Quantization::Bitsandbytes)) => { tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); vec![] } @@ -2172,26 +2168,21 @@ fn main() -> Result<(), LauncherError> { } // capture adapter_id, path, revision in format of adapter_id=path@revision - let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap(); - if let Some(caps) = re.captures(adapter) { - let adapter_id = caps.get(1).map_or("", |m| m.as_str()); - let revision = caps.get(3).map(|m| m.as_str()); - - download_convert_model( - adapter_id, - revision, - args.trust_remote_code, - args.huggingface_hub_cache.as_deref(), - args.weights_cache_override.as_deref(), - running.clone(), - false, // avoid merging lora adapters if using multi-lora - )?; - } else { - return Err(LauncherError::ArgumentValidation(format!( - "Invalid LoRA adapter format: {}", - adapter - ))); - } + // path is disabled beforehand. 
+ let mut splits = adapter.split("@"); + let adapter_id = splits.next().ok_or_else(|| { + LauncherError::ArgumentValidation("Missing adapter id".to_string()) + })?; + let revision = splits.next(); + download_convert_model( + adapter_id, + revision, + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + false, // avoid merging lora adapters if using multi-lora + )?; } } diff --git a/router/Cargo.toml b/router/Cargo.toml index 9258fe03..2e621dfc 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -17,7 +17,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" -jsonschema = { version = "0.17.1", features = ["draft202012"] } +jsonschema = { version = "0.28.0" } metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" @@ -25,7 +25,7 @@ opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.13.0" outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" } rand = "0.8.5" -reqwest = { version = "0.11.20", features = [] } +reqwest = { version = "0.11.20", features = ["blocking"] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" diff --git a/router/src/config.rs b/router/src/config.rs index 5d07a293..4d5fcfa0 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -110,6 +110,24 @@ pub struct ClipVisionModel { patch_size: usize, } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Idefics3 {} + +impl Idefics3 { + pub fn get_max_longest_edge(&self) -> usize { + 364 + } + + pub fn get_number_of_features(&self) -> usize { + 169 + } + + pub fn get_max_longest_edge_for_image_resize(&self) -> usize { + 1456 + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct Idefics2 {} @@ -178,6 +196,7 @@ pub enum Config { Idefics, Mllama, Idefics2(Idefics2), + Idefics3(Idefics3), Ssm, GptBigcode, Granite, diff --git a/router/src/kserve.rs b/router/src/kserve.rs index c53fa481..ea85eb8c 100644 --- a/router/src/kserve.rs +++ b/router/src/kserve.rs @@ -205,6 +205,7 @@ pub async fn kserve_model_infer( let generate_request = GenerateRequest { inputs: str_input.to_string(), parameters: payload.parameters.clone(), + add_special_tokens: true, }; let infer = infer.clone(); let compute_type = compute_type.clone(); @@ -212,7 +213,7 @@ pub async fn kserve_model_infer( async move { generate_internal(infer, compute_type, Json(generate_request), span) .await - .map(|(_, Json(generation))| { + .map(|(_, _, Json(generation))| { let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); OutputChunk { name: output.name.clone(), diff --git a/router/src/lib.rs b/router/src/lib.rs index 84e9bc48..dbd36827 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -79,7 +79,7 @@ impl TokenizerTrait for tokenizers::Tokenizer { } } -impl<'a> TokenizerTrait for PyTokenizer<'a> { +impl TokenizerTrait for PyTokenizer<'_> { fn encode_trait( &self, query: String, @@ -170,6 +170,7 @@ impl TokenizerConfigToken { #[serde(tag = "processor_class")] pub enum HubPreprocessorConfig { Idefics2Processor(Idefics2Preprocessor), + Idefics3Processor(Idefics2Preprocessor), } impl HubPreprocessorConfig { diff --git a/router/src/validation.rs b/router/src/validation.rs index 8137ac58..7ac05b21 100644 --- a/router/src/validation.rs +++ 
b/router/src/validation.rs
@@ -7,7 +7,6 @@ use crate::{
 use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
-use jsonschema::{Draft, JSONSchema};
 use outlines_core::json_schema::to_regex as json_schema_to_regex;
 use rand::{thread_rng, Rng};
 use serde_json::Value;
@@ -355,9 +354,7 @@ impl Validation {
         }?;
 
         // Check if the json is a valid JSONSchema
-        JSONSchema::options()
-            .with_draft(Draft::Draft202012)
-            .compile(&json)
+        jsonschema::draft202012::meta::validate(&json)
             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
         // The schema can be valid but lack properties.
@@ -614,6 +611,73 @@ fn image_tokens(
             image_string
         }
+        Idefics3(config) => {
+            const FAKE: &str = "<fake_token_around_image>";
+            const IMAGE: &str = "<image>";
+            const GLOBAL_IMG: &str = "<global-img>";
+
+            let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
+
+            // resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
+            let (height, width) = if height > max_longest_edge_for_image_resize
+                || width > max_longest_edge_for_image_resize
+            {
+                let aspect_ratio = height as f32 / width as f32;
+                if height > width {
+                    (
+                        max_longest_edge_for_image_resize,
+                        (max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
+                    )
+                } else {
+                    (
+                        (max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
+                        max_longest_edge_for_image_resize,
+                    )
+                }
+            } else {
+                (height, width)
+            };
+
+            let image_seq_len = config.get_number_of_features();
+            let max_edge = config.get_max_longest_edge();
+
+            let (image_rows, image_cols) = if height > max_edge || width > max_edge {
+                (
+                    (height as f32 / max_edge as f32).ceil() as usize,
+                    (width as f32 / max_edge as f32).ceil() as usize,
+                )
+            } else {
+                (0, 0)
+            };
+
+            let mut image_string = String::new();
+
+            if image_rows == 0 && image_cols == 0 {
+                // Single image case
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            } else {
+                // Split image case
+                for n_h in 0..image_rows {
+                    for n_w in 0..image_cols {
+                        image_string.push_str(FAKE);
+                        image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
+                        image_string.push_str(&IMAGE.repeat(image_seq_len));
+                    }
+                    image_string.push('\n');
+                }
+
+                image_string.push('\n');
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            }
+
+            image_string
+        }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         Qwen2Vl(config) => format!(
@@ -647,7 +711,8 @@ fn prepare_input(
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
         Some(
-            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
+            config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
+            | Qwen2Vl(_)),
         ) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
@@ -1164,12 +1229,11 @@ mod tests {
         assert!(
             chunks
                 == vec![
-                    Chunk::Text("test".to_string()).into(),
+                    Chunk::Text("test".to_string()),
                     Chunk::Image(Image {
                         data: pixel_data.clone(),
                         mimetype: "image/gif".to_string()
                     })
-                    .into()
                 ],
            "Failed to process images",
        );
@@ -1224,17 +1288,15 @@ mod tests {
         assert!(
             chunks
                 == vec![
-
Chunk::Text("test".to_string()).into(), + Chunk::Text("test".to_string()), + Chunk::Image(Image { + data: pixel_data.clone(), + mimetype: "image/gif".to_string() + }), Chunk::Image(Image { data: pixel_data.clone(), mimetype: "image/gif".to_string() }) - .into(), - Chunk::Image(Image { - data: pixel_data.clone(), - mimetype: "image/gif".to_string() - }) - .into() ], "Failed to process images", ); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 12d58532..25959e0e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] # Released on: June 13, 2024 # https://releases.rs/docs/1.79.0/ -channel = "1.80.1" +channel = "1.84.0" components = ["rustfmt", "clippy"] diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index a9cdf782..9a946d97 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,5 +1,5 @@ flash_att_v2_commit_cuda := v2.6.1 -flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 +flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd build-flash-attention-v2-cuda: pip install -U packaging wheel diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer index f0a27622..d5f684ba 100644 --- a/server/Makefile-flashinfer +++ b/server/Makefile-flashinfer @@ -1,2 +1,5 @@ install-flashinfer: - pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 + # We need fsspec as an additional dependency, but + # `pip install flashinfer` cannot resolve it. + pip install fsspec + pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/server/poetry.lock b/server/poetry.lock index 7cf440dd..93db8dc9 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -290,22 +290,23 @@ tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "bitsandbytes" -version = "0.43.3" +version = "0.45.0" description = "k-bit optimizers and matrix multiplication routines." 
optional = true python-versions = "*" files = [ - {file = "bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cc99507c352be0715098b2c7577b690dd158972dc4ea10c7495bac104c7c79f0"}, - {file = "bitsandbytes-0.43.3-py3-none-win_amd64.whl", hash = "sha256:257f6552f2144748a84e6c44e1f7a98f3da888f675ed74e18fd7f7eb13c6cafa"}, + {file = "bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:0f0323de1ff1fdf8383e79bdad1283516a4c05a6fd2b44a363bf4e059422305b"}, + {file = "bitsandbytes-0.45.0-py3-none-win_amd64.whl", hash = "sha256:ebbf96e0ecb466716a65ecdeaef3fa1983575447b9ab66b74e5211892507c6ff"}, ] [package.dependencies] numpy = "*" torch = "*" +typing_extensions = ">=4.8.0" [package.extras] benchmark = ["matplotlib", "pandas"] -test = ["scipy"] +test = ["lion_pytorch", "scipy"] [[package]] name = "certifi" @@ -1289,12 +1290,12 @@ files = [ [[package]] name = "marlin-kernels" -version = "0.3.6" +version = "0.3.7" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.6+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:afedaa9a15e8991442bc8c81f62833fbf5c1556ae9d7a5a9e13b747ce97beef9"}, + {file = "marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:bb416d14623dc0ad0eeb2835446c37a41f994555f1baec8701de6d4c1fc17ec8"}, ] [package.dependencies] @@ -1302,16 +1303,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.6" +version = "0.3.7" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.6+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:c0c05621d5e87144415d8a6e439072bd844d5f3cb55e4c4c69eabdc4c94610f4"}, + {file = "marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:a89bb61d718002d4432158641bce95c6fd68f9ee1a7d5402dd283903397f3185"}, ] [package.dependencies] @@ -1319,16 +1320,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.6" +version = "0.3.7" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.6+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:3be4662c8d25a3cdb1793dafe0e2e76dd600913a69a468e2c68d1fed4e149255"}, + {file = "marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:ed938d196fc5e9cce9fc44cd2b889d5adc5ca7475c8a23858f1474d29e38bdbf"}, ] [package.dependencies] @@ -1336,16 +1337,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.6" +version = "0.3.7" description = 
"Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.6+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:89eac9d46bc084a256b538afda6053683eb7e505db0e0d4f6dbeca32368caac6"}, + {file = "marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:113c54f68565ad476ca12366b4de92131fa3e9ddb16cbe8ad63272972a15ac28"}, ] [package.dependencies] @@ -1353,7 +1354,7 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mdurl" @@ -4097,4 +4098,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "c7fdcff2b752cd3beb3995c1ecd15f0f4d9b4e117048b06ab991c6d0e0c86ff3" +content-hash = "0ead8472620eeef6f9ff81f70bcb48403f9c831b6914245efa5e249724d80d0b" diff --git a/server/pyproject.toml b/server/pyproject.toml index 0d56e9c7..0386ae55 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.4" typer = "^0.12.5" accelerate = {version = "^1.1.0", optional = true} -bitsandbytes = { version = "^0.43.0", optional = true } +bitsandbytes = { version = "^0.45.0", optional = true } safetensors = "^0.4.5" loguru = "^0.7.2" opentelemetry-api = "^1.27.0" @@ -48,10 +48,10 @@ attention-kernels = [ { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] marlin-kernels = [ - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.6/marlin_kernels-0.3.6+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py index a27c1055..ab0312e4 100644 --- 
a/server/tests/utils/test_adapter.py +++ b/server/tests/utils/test_adapter.py @@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj(): # assert the result expected = { + (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc), + (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), @@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility(): result = get_mlp_weights(3, mock_layer) expected = { + (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc), + (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), @@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility(): result = get_mlp_weights(3, mock_layer) expected = { + (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc), + (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index f1edd9a0..cdcfe91b 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -6,9 +6,11 @@ from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple, Type, Union +from loguru import logger import torch from peft import LoraConfig as _LoraConfig from torch.distributed import ProcessGroup +from text_generation_server.utils.log import log_master from text_generation_server.adapters.config import AdapterConfig, ModuleMap @@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights): lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers + # import ipdb; ipdb.set_trace() for layer_id in range(nlayers): key = (layer_id, layer_type) + if key not in target_to_layer: + # There is no layer of this type in the model + log_master( + logger.warning, + f"Key specified in lora weights but not found in base model: {key}", + ) + return None + weight_name, layer = target_to_layer[key] base_weight = layer.base_layer.linear.weight base_device = base_weight.device diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 3038602e..7b5af3c4 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -60,8 +60,7 @@ def paged_attention( from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( - # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. 
- query.contiguous(), + query, paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, @@ -231,8 +230,7 @@ def attention( softcap = 0.0 return prefill_with_paged_kv_state.get().forward( - # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. - query.contiguous(), + query, causal=causal, paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 26a72d9b..909eea27 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state( num_kv_heads: int, head_size: int, page_size: int, - dtype: torch.dtype, + kv_dtype: torch.dtype, + q_dtype: torch.dtype, window_left: int, ): """ @@ -91,9 +92,10 @@ def use_prefill_with_paged_kv_state( num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, - q_data_type=dtype, + kv_data_type=kv_dtype, + q_data_type=q_dtype, page_size=page_size, - window_left=window_left, + window_left=-1 if window_left is None else window_left, ) yield finally: @@ -113,41 +115,6 @@ def create_prefill_state( ) -@contextmanager -def use_prefill_state( - *, - state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, - cu_seqlens: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - token = prefill_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - kv_indptr=cu_seqlens, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_state.reset(token) - - def create_decode_state( *, device: torch.device, @@ -205,7 +172,7 @@ def use_decode_state( head_size: int, page_size: int, kv_cache_dtype: torch.dtype, - dtype: torch.dtype, + q_dtype: torch.dtype, window_left: int, ): """ @@ -242,8 +209,8 @@ def use_decode_state( head_dim=head_size, page_size=page_size, data_type=kv_cache_dtype, - q_data_type=dtype, - window_left=window_left, + q_data_type=q_dtype, + window_left=-1 if window_left is None else window_left, ) yield finally: diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index a5ab0ae9..146c15e9 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -5,6 +5,10 @@ from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master +from text_generation_server.models.globals import ( + ATTENTION, + BLOCK_SIZE, +) from loguru import logger import vllm._custom_ops as ops @@ -73,11 +77,44 @@ def paged_attention( # limitations under the License. 
# + if ATTENTION == "flashdecoding": + max_q = 1 + max_k = max_s + import flash_attn_2_cuda + + if softcap is None: + softcap = 0.0 + out = flash_attn_2_cuda.varlen_fwd( + query, + kv_cache.key, + kv_cache.value, + None, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, # pad_k + None, + block_tables, + None, + max_q, + max_k, + 0.0, # dropout + softmax_scale, + False, # zero_tensors + True, # causal + -1, # Window_left + -1, # Window right + softcap, + False, # return softmax + None, # generator + ) + return out[0] + if softcap is not None: raise RuntimeError("Paged attention doesn't support softcapping") # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = kv_cache.value.shape[3] + # block_size = kv_cache.value.shape[3] + block_size = BLOCK_SIZE num_seqs, num_heads, head_size = query.shape num_kv_heads = kv_cache.key.shape[1] @@ -126,17 +163,17 @@ def paged_attention( else: # Run PagedAttention V2. assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( + tmp_output = torch.zeros( size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=out.dtype, device=out.device, ) - exp_sums = torch.empty( + exp_sums = torch.zeros( size=(num_seqs, num_heads, max_num_partitions), dtype=torch.float32, device=out.device, ) - max_logits = torch.empty_like(exp_sums) + max_logits = torch.zeros_like(exp_sums) if not use_custom: ops.paged_attention_v2( @@ -247,14 +284,15 @@ def attention( # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( query, - key, - value, + # flashdecoding: pass the KV caches, paged: pass the KV. + kv_cache.key if ATTENTION == "flashdecoding" else key, + kv_cache.value if ATTENTION == "flashdecoding" else value, out, seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - None, + seqlen.cu_seqlen_k, None, None, + block_tables if ATTENTION == "flashdecoding" else None, None, seqlen.max_q, seqlen.max_k, diff --git a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py index e63c5212..ebcc06d6 100644 --- a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py +++ b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py @@ -3,8 +3,14 @@ from typing import List, Optional, Union import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationType -from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale +from text_generation_server.layers.fp8 import ( + Fp8Weight, + _load_scalar_or_matrix_scale, + requantize_with_max_scale, + normalize_e4m3fn_to_e4m3fnuz, +) from text_generation_server.utils.weights import Weights, WeightsLoader +from text_generation_server.utils.import_utils import SYSTEM class W8ANFpLoader(WeightsLoader): @@ -47,11 +53,10 @@ class W8ANFpLoader(WeightsLoader): weight_scale = None if self.load_weight_scale: - weight_scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) - ) + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + if SYSTEM == "cuda": + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) input_scale = None if self.load_input_scale: @@ -87,7 +92,8 @@ class W8ANFpLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) - weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + if SYSTEM == "cuda": + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) 
input_scale = None if self.load_input_scale: @@ -141,6 +147,17 @@ class W8ANFpLoader(WeightsLoader): else None ) + if self.load_weight_scale and SYSTEM == "rocm": + w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + w, weight_scale, input_scale + ) + + if weight_scale.numel() == len(prefixes): + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, weight_scale.to(weights.device), logical_widths, weights.dtype + ) + return Fp8Weight( weight=w, weight_scale=weight_scale, @@ -153,11 +170,10 @@ class W8ANFpLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) weight_scale = None if self.load_weight_scale: - weight_scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) - ) + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + if SYSTEM == "cuda": + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) input_scale = None if self.load_input_scale: diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 1e5c8b3d..4e83ec9d 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -19,6 +19,9 @@ try: except ImportError: marlin_kernels = None +quant_dtype: torch.dtype = ( + torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn +) if SYSTEM == "cuda" and marlin_kernels is not None: major, minor = torch.cuda.get_device_capability() @@ -60,25 +63,58 @@ def normalize_e4m3fn_to_e4m3fnuz( weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - assert weight.dtype == torch.float8_e4m3fn - # The bits pattern 10000000(-128) represents zero in e4m3fn - # but NaN in e4m3fnuz. So here we set it to 0. - # https://onnx.ai/onnx/technical/float8.html - weight_as_int8 = weight.view(torch.int8) - ROCM_FP8_NAN_AS_INT = -128 - weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 - weight = weight_as_int8.view(torch.float8_e4m3fnuz) + if weight.dtype == torch.float8_e4m3fn: + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) - # For the same bits representation, e4m3fnuz value is half of - # the e4m3fn value, so we should double the scaling factor to - # get the same dequantized value. - # https://onnx.ai/onnx/technical/float8.html - weight_scale = weight_scale * 2.0 - if input_scale is not None: - input_scale = input_scale * 2.0 + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. 
+ # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 return weight, weight_scale, input_scale +def per_tensor_dequantize( + tensor: torch.Tensor, + inv_scale: Union[float, torch.Tensor], + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + fake_qweight = tensor.to(dtype) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def requantize_with_max_scale( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: int, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. + max_w_scale = weight_scale.max().float() + + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize( + weight[start:end, :], weight_scale[idx], dtype + ) + weight[start:end, :], max_w_scale_normalized = fp8_quantize( + weight_dq, max_w_scale + ) + start = end + + return weight, max_w_scale_normalized + + def fp8_quantize( weight: torch.Tensor, scale: Optional[torch.Tensor] = None, @@ -96,7 +132,7 @@ def fp8_quantize( shape = weight.shape qweight, scale = marlin_kernels.scaled_fp8_quant( weight.reshape(-1, shape[-1]), - dtype=qdtype, + dtype=quant_dtype, scale=scale, scale_ub=scale_upper_bound, # TODO: don't do this when we have to use the Torch kernel. @@ -116,6 +152,8 @@ def fp8_quantize( qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) scale = scale.float().reciprocal() else: + if SYSTEM == "rocm": + scale = scale / 2.0 # Use reciprocal to avoid more expensive division. qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max) @@ -141,17 +179,18 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: # FP8 branch - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) - ) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + if SYSTEM == "cuda": + scale.reshape(-1).expand(w.shape[0]) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): - input_scale = weights.get_tensor( - f"{prefix}.input_scale", to_dtype=False - ).reshape(-1) + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) return Fp8Weight( weight=w, @@ -178,6 +217,7 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if scale.numel() > 1: scale = weights.get_packed_sharded( f"{prefix}.weight_scale", @@ -185,7 +225,8 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) - scale = scale.reshape(-1).expand(w.shape[0]) + if SYSTEM == "cuda": + scale = scale.reshape(-1).expand(w.shape[0]) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -243,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) + if SYSTEM == "rocm": + w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + w, scale, input_scale + ) + + if scale.numel() == len(prefixes): + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.to(weights.device), logical_widths, weights.dtype + ) + return Fp8Weight( weight=w, weight_scale=scale, @@ -259,16 +311,18 @@ class HybridFP8UnquantLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: - scale = ( - 
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) - ) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + if SYSTEM == "cuda": + scale = scale.reshape(-1).expand(w.shape[0]) + input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): - input_scale = weights.get_tensor( - f"{prefix}.input_scale", to_dtype=False - ).reshape(-1) + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) return Fp8Weight( weight=w, @@ -326,7 +380,7 @@ class Fp8Linear(torch.nn.Module): if CUTLASS_FP8_AVAILABLE: log_once(logger.info, "Using cutlass w8a8 kernels") if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn: - qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz( + qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=qweight, weight_scale=scale ) @@ -443,6 +497,9 @@ class Fp8Linear(torch.nn.Module): def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) + if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) + elif SYSTEM == "rocm": + return scale.reshape(-1) return scale.reshape(-1).expand(shape[0]) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fcc79608..e2d24643 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -152,6 +152,9 @@ try: from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) + from text_generation_server.models.custom_modeling.idefics3 import ( + Idefics3ForConditionalGeneration, + ) from text_generation_server.models.custom_modeling.qwen2_vl import ( Qwen2VLForConditionalGeneration, ) @@ -188,6 +191,12 @@ class ModelType(enum.Enum): "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", "multimodal": True, } + IDEFICS3 = { + "type": "idefics3", + "name": "Idefics 3", + "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3", + "multimodal": True, + } LLAVA_NEXT = { "type": "llava_next", "name": "Llava Next (1.6)", @@ -1253,6 +1262,24 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == IDEFICS3: + if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=Idefics3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. 
+ processor_kwargs={"size": {"longest_edge": 1456}}, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: if FLASH_ATTENTION: return VlmCausalLM( @@ -1422,6 +1449,9 @@ def get_model_with_lora_adapters( "up_proj", "down_proj", "qkv_proj", + # add c_* layers used in starcoder2 + "c_proj", + "c_fc", ] for layer_name in adapter_layers: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 10309006..28db42fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -515,9 +515,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=f"{prefix}.layers.0", config=config, weights=weights, ) @@ -533,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaCrossLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -546,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -561,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=last_layer_id, - prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}.model.layers.{last_layer_id}" - ), + prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, ) ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -629,19 +615,24 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}.model.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: @@ -652,11 +643,13 @@ class FlashLlamaForCausalLM(torch.nn.Module): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier + prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}" + with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, + prefix, + weights, ) # Used in Granite diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index c793982d..5e090369 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -32,6 +32,8 @@ from text_generation_server.layers.attention import ( Seqlen, ) from text_generation_server.layers import ( + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, @@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig): ) -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): + prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + head_size = config.hidden_size // config.num_attention_heads + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] if config.num_attention_heads != config.num_key_value_heads: - return _load_gqa(config, prefix, weights) + base_layer = _load_gqa(config, prefix, weights) else: - return TensorParallelColumnLinear.load_multi( + base_layer = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + prefixes=prefixes, dim=0, weights=weights, bias=config.use_bias, ) + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) def _load_gqa(config, prefix: str, weights): @@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights): class Starcoder2Attention(torch.nn.Module): def __init__( self, + index: int, prefix: str, config, weights, @@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) self.kv_scales = get_kv_scales(weights, f"{prefix}") - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=config.use_bias, + bias=getattr(config, "use_bias", False), ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module): kv_scales=self.kv_scales, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class Starcoder2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() act = config.hidden_act self.act = ( @@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module): ) ) # Fuse gate and up proj - self.c_fc = TensorParallelColumnLinear.load( + c_fc = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.c_fc", weights=weights, bias=config.use_bias, ) - self.c_proj = TensorParallelRowLinear.load( + 
c_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.c_proj", weights=weights, bias=config.use_bias, ) - def forward(self, hidden_states): - hidden_states = self.c_fc(hidden_states) + self.c_fc = TensorParallelMultiAdapterLinear.load( + c_fc, + layer_id=index, + layer_names=[f"{prefix}.c_fc"], + sizes=[config.intermediate_size, config.intermediate_size], + process_group=weights.process_group, + ) + + self.c_proj = TensorParallelAdapterRowLinear.load( + c_proj, + index, + "c_proj", + process_group=weights.process_group, + ) + + def forward(self, hidden_states, adapter_data): + hidden_states = self.c_fc(hidden_states, adapter_data) hidden_states = self.act(hidden_states) - return self.c_proj(hidden_states) + return self.c_proj(hidden_states, adapter_data) class Starcoder2GatedMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + prefixes=prefixes, weights=weights, dim=0, bias=config.use_bias, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=config.use_bias, ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) STARCODER2_NORMALIZATION_CLASSES = { @@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = Starcoder2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( - prefix=f"{prefix}.mlp", config=config, weights=weights + prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id ) self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( @@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) # faster post attention rms norm @@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) 
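+        # adapter_data threads the per-request LoRA state down to the
+        # TensorParallelMultiAdapterLinear / TensorParallelAdapterRowLinear
+        # wrappers introduced above (assumed to follow the same pattern as the
+        # other LoRA-enabled flash models).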
return mlp_output, attn_res @@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): max_s, true_max_s, prefill_cache_indices, + adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py new file mode 100644 index 00000000..580398cb --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -0,0 +1,584 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics3 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics3VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. 
+ """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics3VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics3VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Idefics3EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics3VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = 
nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics3VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics3Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics3EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics3VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = Idefics3VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics3Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + else: + patch_attention_mask = _prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Idefics3SimpleMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + proj = nn.Parameter( + weights.get_tensor(f"{prefix}.modality_projection.proj.weight"), + requires_grad=False, + ).to(weights.dtype) + self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj.weight = proj + + def forward(self, x): + return self.proj(x) + + +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics3SimpleMLP(prefix, config, weights) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight` + # since Idefics3 uses the `embed_tokens` for the final prediction + # config.text_config.tie_word_embeddings = True + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. 
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics3VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics3Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + mask = input_ids == self.config.image_token_id + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 82e409a6..94b8522d 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(prefix, config, weights, name=name) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5d376990..d097c54f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1288,7 +1288,7 @@ class FlashCausalLM(Model): weights_loader=weights_loader, ) - prefix = "" + prefix = None model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) @@ -1595,7 +1595,9 @@ class FlashCausalLM(Model): if max_total_tokens is None: if get_support_chunking(): model_max_length = self.tokenizer.model_max_length - max_position_embeddings = self.config.max_position_embeddings + max_position_embeddings = getattr( + self.config, "max_position_embeddings", model_max_length + ) 
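Aside on the padding handling in the Idefics3 forward pass above: it drops all-zero padding images, then collapses the pixel-level attention mask into one boolean per patch with two `unfold` calls. A minimal, self-contained sketch of that logic follows; the helper name and tensor shapes are illustrative, not part of the TGI API.

```python
import torch


def build_patch_attention_mask(pixel_values, pixel_attention_mask, patch_size):
    # Drop padding images: a padding image is all zeros across C, H, W.
    nb_values_per_image = pixel_values.shape[1:].numel()
    real = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
    pixel_values = pixel_values[real].contiguous()
    pixel_attention_mask = pixel_attention_mask[real].contiguous()

    # Collapse the pixel-level mask to one boolean per patch:
    # a patch is attended to if any of its pixels is unmasked.
    subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)
    subgrid = subgrid.unfold(2, patch_size, patch_size)
    patch_attention_mask = subgrid.sum(dim=(-1, -2)) > 0
    return pixel_values, patch_attention_mask


# Illustrative shapes: 3 images (the second is an all-zero padding image), 14x14 patches.
pixels = torch.randn(3, 3, 28, 28)
pixels[1] = 0.0
mask = torch.ones(3, 28, 28, dtype=torch.bool)
vals, patch_mask = build_patch_attention_mask(pixels, mask, patch_size=14)
print(vals.shape, patch_mask.shape)  # torch.Size([2, 3, 28, 28]) torch.Size([2, 2, 2])
```

Unfolding with `size=step=patch_size` tiles the mask without overlap, which is why a single `sum` per tile is enough to decide whether the patch stays visible.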
max_total_tokens = min( num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings ) @@ -1663,7 +1665,7 @@ class FlashCausalLM(Model): for seqlen in tuning_sequences: log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") - self.tunableop_warmup(seqlen) + self.tunableop_warmup(seqlen, max_total_tokens) torch.cuda.tunable.write_file(tunableop_filepath) if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": torch.cuda.tunable.tuning_enable(False) @@ -1710,7 +1712,7 @@ class FlashCausalLM(Model): assert max_total_tokens is not None return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - def tunableop_warmup(self, seqlen: int): + def tunableop_warmup(self, seqlen: int, max_bt: int): input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) @@ -1724,11 +1726,15 @@ class FlashCausalLM(Model): [0, seqlen], device=self.device, dtype=torch.int32 ) max_s = seqlen + + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(seqlen) + block_tables = block_tables.reshape((seqlen, max_bt)) + seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=1, max_k=seqlen, ) @@ -1738,7 +1744,7 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=None, + block_tables=block_tables, seqlen=seqlen, slots=slots, max_s=max_s, @@ -2480,7 +2486,8 @@ class FlashCausalLM(Model): num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, - dtype=self.dtype, + kv_dtype=self.kv_cache_dtype, + q_dtype=self.dtype, window_left=self.sliding_window, ) else: @@ -2494,6 +2501,6 @@ class FlashCausalLM(Model): head_size=self.head_size, page_size=BLOCK_SIZE, kv_cache_dtype=self.kv_cache_dtype, - dtype=self.dtype, + q_dtype=self.dtype, window_left=self.sliding_window, ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 81b4369b..db78341d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLM, ) from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION +from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -23,6 +24,40 @@ tracer = trace.get_tracer(__name__) IDEFICS2_FAKE_TOKEN = "" IDEFICS2_IMAGE_TOKEN = "" +IDEFICS3_IMAGE_TOKEN = "" +IDEFICS3_FAKE_IMAGE_TOKEN = "" +IDEFICS3_GLOBAL_IMG_TOKEN = "" + + +# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 +def _prompt_split_image( + *, + image_seq_len: int, + image_rows: int, + image_cols: int, + fake_token_around_image: str, + image_token: str, + global_img_token: str, +): + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" + + 
text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + return text_split_images + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -54,10 +89,26 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str if processor.image_processor.do_image_splitting: image_str *= 5 return image_str + if config.model_type == "idefics3": + # TODO: implement this in a more general way + n_rows = image_input["rows"][0][image_id] + n_cols = image_input["cols"][0][image_id] + image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) + image_str = _prompt_split_image( + image_seq_len=image_seq_len, + image_rows=n_rows, + image_cols=n_cols, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + image_token=IDEFICS3_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, + ) + return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) - from loguru import logger log_master( logger.info, @@ -194,12 +245,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch): raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: - image_inputs = processor.image_processor(images, return_tensors="pt") + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + image_inputs = processor.image_processor( + images, return_tensors="pt", **kwargs + ) else: image_inputs = None - batch_inputs = [] - max_truncation = 0 + batch_tokenized_inputs = [] + max_length = 0 image_id = 0 for r in requests: full_text = "" @@ -214,16 +274,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch): image_id += 1 full_text = image_text_replacement_fixup(config, full_text) - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=not config.model_type == "paligemma", - )["input_ids"] + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + batch_tokenized_inputs.append(input_ids) return batch_tokenized_inputs, image_inputs diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 09254b68..50abfafd 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -281,6 +281,12 @@ def get_mlp_weights(i, layer): if hasattr(mlp, "up_proj"): weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + if hasattr(mlp, "c_fc"): + weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc) + + if hasattr(mlp, "c_proj"): + weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj) + if hasattr(mlp, "down_proj"): weights[(i, "down_proj")] = ( f"model.layers.{i}.mlp.down_proj", diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 82aeba6c..1b766ddf 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -1,6 +1,6 @@ import os import torch - +from torch.distributed import ProcessGroup 
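For reference, the Idefics3 prompt expansion added in `vlm_causal_lm.py` above replaces each image with a grid of per-tile placeholder blocks plus one global-image block; in `transformers` the token constants are `<fake_token_around_image>`, `<image>`, and `<global-img>`, with `<row_i_col_j>` markers per tile. A rough sketch of the expansion and the `image_seq_len` arithmetic, assuming the published Idefics3 defaults (`image_size=364`, `patch_size=14`, `scale_factor=2`); treat the exact token strings and values as illustrative.

```python
# Hypothetical stand-ins for the processor's special tokens.
FAKE = "<fake_token_around_image>"
IMG = "<image>"
GLOBAL = "<global-img>"


def split_image_prompt(image_seq_len: int, rows: int, cols: int) -> str:
    """Expand one split image into its placeholder token string."""
    out = ""
    for r in range(rows):
        for c in range(cols):
            out += f"{FAKE}<row_{r + 1}_col_{c + 1}>" + IMG * image_seq_len
        out += "\n"
    # Trailing global-image block, wrapped in fake tokens like the per-tile blocks.
    out += f"\n{FAKE}{GLOBAL}" + IMG * image_seq_len + FAKE
    return out


# (image_size // patch_size) ** 2 patches per tile, reduced by scale_factor ** 2
# through the connector's pixel shuffle.
image_seq_len = (364 // 14) ** 2 // 2**2
print(image_seq_len)  # 169
prompt = split_image_prompt(image_seq_len, rows=2, cols=3)
print(prompt.count(IMG))  # (2*3 tiles + 1 global) * 169 = 1183 placeholders
```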
 from datetime import timedelta
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
@@ -18,10 +18,11 @@ class FakeBarrier:
         pass


-class FakeGroup:
+class FakeGroup(ProcessGroup):
     def __init__(self, rank, size):
         self._rank = rank
         self._size = size
+        super().__init__(rank, size)

     def allreduce(self, *args, **kwargs):
         return FakeBarrier()
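The final hunk turns `FakeGroup` into a real `torch.distributed.ProcessGroup` subclass, so `isinstance` checks inside torch accept it while the collectives remain no-ops. A minimal sketch of that pattern, mirroring the constructor call in the patch; the method set and signatures are illustrative (the real class also stubs `allgather`), and it assumes a torch build whose `ProcessGroup` exposes the `(rank, size)` constructor.

```python
from torch.distributed import ProcessGroup


class FakeBarrier:
    def wait(self):
        pass


class FakeGroup(ProcessGroup):
    """Single-process stand-in for a real process group: type checks pass,
    but every collective returns immediately."""

    def __init__(self, rank: int, size: int):
        # Same call the patch adds; registers rank/world size with the base class.
        super().__init__(rank, size)
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self) -> int:
        return self._size

    def rank(self) -> int:
        return self._rank


group = FakeGroup(rank=0, size=1)
print(isinstance(group, ProcessGroup), group.size(), group.rank())  # True 1 0
```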