drbh 2025-01-09 13:13:45 +00:00 committed by GitHub
commit 33299d1ee9
26 changed files with 772 additions and 65 deletions


@@ -20,7 +20,7 @@ jobs:
- name: Install Protocol Buffers compiler
run: |
sudo apt-get update
-sudo apt-get install -y protobuf-compiler libprotobuf-dev
+sudo apt-get install -y protobuf-compiler libprotobuf-dev clang libavcodec-dev libavfilter-dev libavdevice-dev libavformat-dev libavutil-dev pkg-config
- name: Install Launcher
id: install-launcher


@@ -43,7 +43,9 @@ jobs:
- name: Install
run: |
sudo apt update
-sudo apt install python3.11-dev -y
+sudo apt install python3.11-dev python3.11-venv python3-pip clang libavcodec-dev libavfilter-dev libavdevice-dev libavformat-dev libavutil-dev pkg-config -y
+export PKG_CONFIG_PATH=$PKG_CONFIG_PATH:/usr/lib/x86_64-linux-gnu/pkgconfig
+python -m pip install --upgrade pip
make install-cpu
- name: Run server tests
run: |

Cargo.lock (generated)

@@ -267,7 +267,7 @@ version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96"
dependencies = [
-"bindgen",
+"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
@@ -454,6 +454,24 @@ dependencies = [
"which",
]
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.6.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 2.0.89",
]
[[package]]
name = "bit-set"
version = "0.5.3"
@@ -487,6 +505,15 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bitreader"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "886559b1e163d56c765bc3a985febb4eee8009f625244511d8ee3c432e08c066"
dependencies = [
"cfg-if",
]
[[package]]
name = "bitstream-io"
version = "2.6.0"
@@ -1194,6 +1221,15 @@ dependencies = [
"zune-inflate",
]
[[package]]
name = "fallible_collections"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd"
dependencies = [
"hashbrown 0.13.2",
]
[[package]]
name = "fancy-regex"
version = "0.11.0"
@@ -1219,6 +1255,31 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "ffmpeg-next"
version = "7.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da02698288e0275e442a47fc12ca26d50daf0d48b15398ba5906f20ac2e2a9f9"
dependencies = [
"bitflags 2.6.0",
"ffmpeg-sys-next",
"libc",
]
[[package]]
name = "ffmpeg-sys-next"
version = "7.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bc3234d0a4b2f7d083699d0860c6c9dd83713908771b60f94a96f8704adfe45"
dependencies = [
"bindgen 0.70.1",
"cc",
"libc",
"num_cpus",
"pkg-config",
"vcpkg",
]
[[package]]
name = "fixedbitset"
version = "0.4.2"
@@ -1512,6 +1573,15 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@@ -2471,6 +2541,20 @@ dependencies = [
"syn 2.0.89",
]
[[package]]
name = "mp4parse"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570"
dependencies = [
"bitreader",
"byteorder",
"fallible_collections",
"log",
"num-traits",
"static_assertions",
]
[[package]]
name = "multimap"
version = "0.10.0"
@@ -4425,6 +4509,7 @@ dependencies = [
"base64 0.22.1",
"clap 4.5.21",
"csv",
"ffmpeg-next",
"futures", "futures",
"futures-util", "futures-util",
"hf-hub", "hf-hub",
@ -4436,6 +4521,7 @@ dependencies = [
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"mp4parse",
"ngrok", "ngrok",
"nohash-hasher", "nohash-hasher",
"once_cell", "once_cell",
@ -4449,6 +4535,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"sysinfo", "sysinfo",
"tempfile",
"thiserror", "thiserror",
"tokenizers", "tokenizers",
"tokio", "tokio",


@@ -20,6 +20,20 @@ FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN apt-get update && apt-get install -y \
ffmpeg \
libavcodec-dev \
libavfilter-dev \
libavdevice-dev \
libavformat-dev \
libavutil-dev \
libswscale-dev \
pkg-config \
libclang-dev \
clang \
&& rm -rf /var/lib/apt/lists/*
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -27,7 +41,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --profile release-opt --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --features video --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
@@ -40,7 +54,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
-RUN cargo build --profile release-opt --frozen
+RUN cargo build --profile release-opt --frozen --features video
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -61,18 +75,18 @@ ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git && \
rm -rf /var/lib/apt/lists/*
# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
@@ -82,12 +96,15 @@ RUN chmod +x ~/mambaforge.sh && \
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
-/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
+/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" "openssl>=3.3.0" ;; \
esac && \
/opt/conda/bin/conda clean -ya
RUN /opt/conda/bin/conda install -y pyOpenSSL
# CUDA kernels builder image
FROM pytorch-install AS kernel-builder
@@ -95,8 +112,8 @@ ARG MAX_JOBS=8
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build cmake \
&& rm -rf /var/lib/apt/lists/*
# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder
@@ -188,12 +205,15 @@ ENV HF_HOME=/data \
WORKDIR /usr/src
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libssl-dev \
ca-certificates \
make \
curl \
git \
&& rm -rf /var/lib/apt/lists/*
# Add ffmpeg libraries to the path
ENV LD_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda
@@ -239,6 +259,8 @@ RUN cd server && \
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
ENV LD_PRELOAD="/opt/conda/lib/libcrypto.so.3"
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1
@@ -247,9 +269,9 @@ ENV EXLLAMA_NO_FLASH_ATTN=1
# The binaries change on every build given we burn the SHA into them
# The deps change less often.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
g++ \
&& rm -rf /var/lib/apt/lists/*
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -258,6 +280,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Copy the ffmpeg libraries
COPY --from=builder /usr/lib/x86_64-linux-gnu/* /usr/lib/x86_64-linux-gnu-copy/
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu-copy"
# AWS Sagemaker compatible image
FROM base AS sagemaker


@@ -9,7 +9,7 @@ use thiserror::Error;
use tonic::transport;
use tonic::Status;
-pub use v3::{Chunk, Image, Input, InputChunk};
+pub use v3::{Chunk, Image, Input, InputChunk, Video};
#[async_trait]
pub trait Health {
@@ -79,6 +79,20 @@ impl ChunksToString for Vec<InputChunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Some(Chunk::Video(Video {
data,
mimetype,
width,
height: _,
frames: _,
})) => {
// TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
let encoded = STANDARD.encode(data);
output.push_str(&format!(
r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
width, mimetype, encoded, mimetype
));
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});


@@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-StoppingCriteriaParameters, Tokens,
+StoppingCriteriaParameters, Tokens, Video,
};
pub use sharded_client::ShardedClient;


@@ -301,6 +301,7 @@ impl TensorRtLlmBackendV2 {
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
+Chunk::Video(_) => Err(ValidationError(UnsupportedModality("video"))),
},
}
}


@@ -15,7 +15,7 @@ pub use grpc_client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-StoppingCriteriaParameters,
+StoppingCriteriaParameters, Video,
};
pub use sharded_client::ShardedClient;


@@ -439,6 +439,13 @@ impl State {
data: image.data,
mimetype: image.mimetype,
}),
Chunk::Video(video) => client::Chunk::Video(client::Video {
data: video.data,
mimetype: video.mimetype,
width: video.width,
height: video.height,
frames: video.num_frames,
}),
}),
})
.collect(),


@@ -1922,6 +1922,24 @@
]
}
}
},
{
"type": "object",
"required": [
"video_url",
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"video_url"
]
},
"video_url": {
"$ref": "#/components/schemas/Url"
}
}
}
],
"discriminator": {


@@ -115,15 +115,17 @@
buildInputs =
[
benchmark
-launcher
-router
-server
+cargo
client
+clippy
+ffmpeg
+launcher
openssl.dev
pkg-config
-cargo
+router
+rustPlatform.bindgenHook
rustfmt
-clippy
+server
]
++ (with python3.pkgs; [
docker


@@ -0,0 +1,19 @@
{
"choices": [
{
"delta": {
"content": "",
"role": "assistant"
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1733450914,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.2-dev0-native",
"usage": null
}


@@ -0,0 +1,75 @@
import pytest
import json
import requests
@pytest.fixture(scope="module")
def qwen2_vl_handle(launcher):
with launcher(
"Qwen/Qwen2-VL-7B-Instruct",
max_input_length=10_000,
max_batch_prefill_tokens=10_000,
max_total_tokens=10_001,
cuda_graphs=[0],
) as handle:
yield handle
@pytest.fixture(scope="module")
async def qwen2_vl(qwen2_vl_handle):
await qwen2_vl_handle.health(300)
return qwen2_vl_handle.client
@pytest.mark.asyncio
async def test_qwen2_vl_simpl(qwen2_vl, response_snapshot):
responses = requests.post(
f"{qwen2_vl.base_url}/v1/chat/completions",
headers=qwen2_vl.headers,
json={
"model": "tgi",
"messages": [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4"
},
},
{
"type": "text",
"text": "Describe this video.",
},
],
},
],
"seed": 42,
"max_tokens": 100,
"stream": True,
},
)
# iterate over the response in chunks
count = 0
full_text = ""
last_response = None
for chunk in responses.iter_content(chunk_size=1024):
if chunk:
count += 1
# remove the "data: " prefix, trailing newline, and split the chunk into individual lines
lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
for line in lines:
if line == "[DONE]":
break
print("=", line)
try:
response = json.loads(line)
# print(response)
last_response = response
full_text += response["choices"][0]["delta"]["content"]
except json.JSONDecodeError:
pass
assert last_response == response_snapshot


@@ -19,6 +19,26 @@ defaultCrateOverrides
};
rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
ffmpeg-sys-next = attrs: {
nativeBuildInputs = [
pkg-config
];
buildInputs = [
rustPlatform.bindgenHook
ffmpeg
];
};
ffmpeg-next = attrs: {
# Somehow the variables that are passed are mangled, so they are not
# correctly passed to the ffmpeg-next build script. Worth investigating
# more since it's probably a bug in crate2nix or buildRustCrate.
postPatch = ''
substituteInPlace build.rs \
--replace-fail "DEP_FFMPEG_" "DEP_FFMPEG_SYS_NEXT_"
'';
};
grpc-metadata = attrs: {
src = filter {
root = ../backends/grpc-metadata;


@@ -5,9 +5,11 @@
cmake,
isort,
ninja,
+rustPlatform,
which,
cudaPackages,
openssl,
+ffmpeg,
pkg-config,
poetry,
protobuf,
@@ -26,6 +28,7 @@
mkShell {
nativeBuildInputs =
[
+rustPlatform.bindgenHook
black
isort
pkg-config
@@ -53,6 +56,7 @@ mkShell {
buildInputs =
[
openssl.dev
+ffmpeg
]
++ (with python3.pkgs; [
venvShellHook


@@ -64,12 +64,31 @@ message Image {
string mimetype = 2;
}
message Video {
/// Binary video data (array of RGB data)
bytes data = 1;
/// Video MIME type.
string mimetype = 2;
/// Video width
uint32 width = 3;
/// Video height
uint32 height = 4;
/// Total number of frames
uint32 frames = 5;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
+/// Video URLs
+Video video = 3;
}
}
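As a sanity check on the wire format above (an illustrative sketch, not part of the proto): `data` carries the sampled frames as packed RGB24, so its length should be `frames * width * height * 3` bytes.

```python
def expected_video_payload_bytes(frames: int, width: int, height: int) -> int:
    # Packed RGB24: 3 bytes per pixel, sampled frames concatenated back to back.
    return frames * width * height * 3

# e.g. 10 sampled frames at the 360x420 target used for Qwen2-VL elsewhere in this commit
assert expected_video_payload_bytes(10, 360, 420) == 4_536_000
```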


@@ -14,20 +14,23 @@ async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
clap = { version = "4.4.5", features = ["derive", "env"] }
+ffmpeg-next = { version = "7.1.0", optional = true }
futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
+mp4parse = { version = "0.17.0", optional = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.13.0"
outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
+opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
+tempfile = { version = "3.10.1", optional = true }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
@@ -74,3 +77,4 @@ default = ["ngrok"]
ngrok = ["dep:ngrok"]
google = []
kserve = []
+video = ["ffmpeg-next", "mp4parse", "tempfile"]


@@ -1173,6 +1173,7 @@ pub struct Url {
pub enum MessageChunk {
Text { text: String },
ImageUrl { image_url: Url },
+VideoUrl { video_url: Url },
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -1229,6 +1230,9 @@ impl From<Message> for TextMessage {
.map(|chunk| match chunk {
MessageChunk::Text { text } => text,
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
MessageChunk::VideoUrl { video_url } => {
format!("<video>({})", video_url.url)
}
})
.collect::<Vec<_>>()
.join(""),


@@ -22,6 +22,15 @@ use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};
#[cfg(feature = "video")]
use ffmpeg_next::{
format::Pixel,
media::Type,
software::scaling::{context::Context, flag::Flags},
};
#[cfg(feature = "video")]
use std::io::Write;
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
/// Validation
@@ -536,6 +545,140 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
#[cfg(not(feature = "video"))]
pub fn fetch_video(
_input: &str,
_target_width: u32,
_target_height: u32,
) -> Result<ProcessedVideo, ValidationError> {
Err(ValidationError::VideoNotSupported)
}
#[cfg(feature = "video")]
pub fn fetch_video(
input: &str,
target_width: u32,
target_height: u32,
) -> Result<ProcessedVideo, ValidationError> {
let (data, mimetype) =
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
let url = &input["<video>(".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
(data, "video/mp4".to_string())
} else if input.starts_with("<video>(data:") {
let content = &input["<video>(data:".len()..input.len() - 1];
let tokens: Vec<&str> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let mimetype = tokens[0];
let content = tokens[1];
if !content.starts_with("base64,") {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let data = STANDARD.decode(&content["base64,".len()..])?;
(data, mimetype.to_string())
} else {
return Err(ValidationError::InvalidVideoContent(input.to_string()));
};
// init ffmpeg
ffmpeg_next::init().map_err(|_| ValidationError::FFmpegError)?;
// create temporary file for ffmpeg input
let mut temp_file = tempfile::NamedTempFile::new().map_err(ValidationError::IoError)?;
temp_file
.write_all(&data)
.map_err(ValidationError::IoError)?;
let mut ictx =
ffmpeg_next::format::input(&temp_file.path()).map_err(|_| ValidationError::FFmpegError)?;
let input = ictx
.streams()
.best(Type::Video)
.ok_or(ValidationError::FFmpegError)?;
let video_stream_index = input.index();
let context_decoder = ffmpeg_next::codec::context::Context::from_parameters(input.parameters())
.map_err(|_| ValidationError::FFmpegError)?;
let mut decoder = context_decoder
.decoder()
.video()
.map_err(|_| ValidationError::FFmpegError)?;
let width = target_width;
let height = target_height;
let mut scaler = Context::get(
decoder.format(),
decoder.width(), // original width
decoder.height(),
Pixel::RGB24,
width, // target width
height,
Flags::BILINEAR,
)
.map_err(|_| ValidationError::FFmpegError)?;
let mut frame_index = 0;
let mut captured_frame_index = 0;
let mut frames = vec![];
let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg_next::decoder::Video,
raw_fps: f32|
-> Result<(), ffmpeg_next::Error> {
let mut decoded = ffmpeg_next::util::frame::video::Video::empty();
let fps = raw_fps.floor();
while decoder.receive_frame(&mut decoded).is_ok() {
let mut rgb_frame = ffmpeg_next::util::frame::video::Video::empty();
scaler.run(&decoded, &mut rgb_frame)?;
if frame_index as f32 % fps == 0.0 {
captured_frame_index += 1;
// Create new buffer without padding
let mut frame_data =
Vec::with_capacity((rgb_frame.width() * rgb_frame.height() * 3) as usize);
let src_data = rgb_frame.data(0);
let row_size = rgb_frame.width() as usize * 3;
// Copy each row without padding
for y in 0..rgb_frame.height() as usize {
let start = y * rgb_frame.stride(0);
let end = start + row_size;
frame_data.extend_from_slice(&src_data[start..end]);
}
frames.push(frame_data);
}
frame_index += 1;
}
Ok(())
};
for (stream, packet) in ictx.packets() {
// Floor the fps to get a whole number
let fps = (stream.rate().numerator() as f32 / stream.rate().denominator() as f32).floor();
if stream.index() == video_stream_index {
decoder
.send_packet(&packet)
.map_err(|_| ValidationError::FFmpegError)?;
receive_and_process_decoded_frames(&mut decoder, fps)
.map_err(|_| ValidationError::FFmpegError)?;
}
}
decoder
.send_eof()
.map_err(|_| ValidationError::FFmpegError)?;
Ok(ProcessedVideo {
mimetype,
height,
width,
frames,
sampled_frames: captured_frame_index,
})
}
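The decoding loop above keeps roughly one frame per second (a frame is captured when `frame_index % floor(fps) == 0`) and copies each captured frame row by row so any stride padding added by the decoder is dropped. A minimal Python sketch of the same bookkeeping, with illustrative names:

```python
import math

def sample_frame_indices(total_frames: int, fps: float) -> list[int]:
    # Keep a frame whenever the index is a multiple of floor(fps),
    # i.e. roughly one frame per second of video.
    step = max(1, math.floor(fps))
    return [i for i in range(total_frames) if i % step == 0]

def strip_row_padding(src: bytes, width: int, height: int, stride: int) -> bytes:
    # RGB24 rows hold width * 3 bytes, but the decoder may pad rows to `stride`;
    # copy only the visible pixels, as the Rust loop does.
    row = width * 3
    return b"".join(src[y * stride : y * stride + row] for y in range(height))

# A 10 s clip at 25 fps (250 frames) yields 10 sampled frames.
assert len(sample_frame_indices(250, 25.0)) == 10
```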
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
@@ -624,6 +767,26 @@ fn image_tokens(
}
}
fn video_tokens(config: &Config, height: u32, width: u32, sampled_frames: f32) -> String {
use Config::*;
match config {
Qwen2Vl(_config) => {
let min_frames = 2_f32;
let max_frames = 256_f32;
// make sure the frames are within the range and are even
let nframes = (sampled_frames).max(min_frames).min(max_frames);
let nframes = (nframes / 2.0).round() as usize * 2;
let num_tokens = nframes * height as usize * width as usize / 1541;
format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|video_pad|>".repeat(num_tokens)
)
}
_ => unimplemented!("Video tokens are not supported for this model configuration"),
}
}
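A worked example of the placeholder count above (the clamp to [2, 256], the rounding to an even frame count, and the 1541 divisor are taken directly from `video_tokens`; the sketch below only reproduces the arithmetic):

```python
def qwen2_vl_video_pad_tokens(sampled_frames: float, height: int, width: int) -> int:
    nframes = min(max(sampled_frames, 2.0), 256.0)   # clamp to [2, 256]
    nframes = int(round(nframes / 2.0) * 2)          # force an even frame count
    return nframes * height * width // 1541

# With the 360x420 Qwen2-VL target resolution and 10 sampled frames:
print(qwen2_vl_video_pad_tokens(10, height=420, width=360))  # 981 <|video_pad|> tokens
```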
fn image_tokens_fixup(config: &Config, text: String) -> String {
match config {
Config::Idefics2(_) => {
@@ -645,6 +808,10 @@ fn prepare_input<T: TokenizerTrait>(
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
// Add video regex
static VIDEO_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
@@ -652,6 +819,53 @@ fn prepare_input<T: TokenizerTrait>(
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
// handle video content first
for chunk in VIDEO_RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let processed_video = match config {
Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) => {
let default_target_width = 224;
let default_target_height = 224;
fetch_video(
&inputs[chunk_start..chunk_end],
default_target_width,
default_target_height,
)?
}
Qwen2Vl(_) => {
let target_width = 360;
let target_height = 420;
fetch_video(&inputs[chunk_start..chunk_end], target_width, target_height)?
}
_ => {
unreachable!("Video tokens are not supported for this model configuration")
}
};
input_chunks.push(Chunk::Video(Video {
data: processed_video.frames.iter().flatten().cloned().collect(),
mimetype: processed_video.mimetype.clone(),
width: processed_video.width,
height: processed_video.height,
num_frames: processed_video.frames.len() as u32,
}));
let video_tokens = video_tokens(
config,
processed_video.height,
processed_video.width,
processed_video.sampled_frames as f32,
);
tokenizer_query.push_str(&video_tokens);
start = chunk_end;
}
// handle image content after video content
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
@@ -660,7 +874,10 @@ fn prepare_input<T: TokenizerTrait>(
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-input_chunks.push(Chunk::Image(Image { data, mimetype }));
+input_chunks.push(Chunk::Image(Image {
+data,
+mimetype: mimetype.clone(),
+}));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
}
@@ -683,7 +900,6 @@ fn prepare_input<T: TokenizerTrait>(
Ok((encoding, input_chunks))
}
type TokenizerRequest = (
(String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
@@ -696,10 +912,28 @@ pub struct Image {
pub mimetype: String,
}
pub struct ProcessedVideo {
mimetype: String,
height: u32,
width: u32,
frames: Vec<Vec<u8>>, // RGB frames
sampled_frames: usize,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Video {
pub data: Vec<u8>,
pub mimetype: String,
pub width: u32,
pub height: u32,
pub num_frames: u32,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image),
+Video(Video),
}
/// Convert input chunks to a stringly-typed input for backwards
@@ -718,6 +952,20 @@ impl ChunksToString for Vec<Chunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Chunk::Video(Video {
data,
mimetype,
width,
height: _,
num_frames: _,
}) => {
// TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
let encoded = STANDARD.encode(data);
output.push_str(&format!(
r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
width, mimetype, encoded, mimetype
));
}
});
output
}
@@ -846,6 +1094,18 @@ pub enum ValidationError {
FailedFetchImage(#[from] reqwest::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
#[error("invalid video content: {0}")]
InvalidVideoContent(String),
#[error("could not parse MP4 file")]
MP4Error,
#[error("no video stream found")]
NoVideoStream,
#[error("io error: {0}")]
IoError(#[from] std::io::Error),
#[error("ffmpeg error")]
FFmpegError,
#[error("video not supported")]
VideoNotSupported,
}
#[cfg(test)]


@@ -81,6 +81,8 @@ class PaliGemmaForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
+video_pixel_values: Optional[torch.FloatTensor] = None,
+video_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.


@@ -751,6 +751,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
+video_pixel_values: Optional[torch.FloatTensor] = None,
+video_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:


@@ -181,6 +181,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
+video_pixel_values: Optional[torch.FloatTensor] = None,
+video_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:


@@ -411,6 +411,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
self,
batch_input_ids: torch.Tensor,
image_grid_thw: Optional[torch.LongTensor] = None,
+video_grid_thw: Optional[torch.LongTensor] = None,
# video_grid_thw is not implemented yet as we do not accept video inputs at the moment
) -> Tuple[torch.Tensor, torch.Tensor]:
if batch_input_ids.dim() == 1:
@@ -424,8 +425,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
device=batch_input_ids.device,
)
d = batch_input_ids.device
-if image_grid_thw is not None:
-image_index = 0
+# Handle both image and video tokens
+if image_grid_thw is not None or video_grid_thw is not None:
+vision_index = 0
llm_pos_ids_list = []
for i, input_ids in enumerate(batch_input_ids):
@@ -433,24 +436,39 @@ class Qwen2VLForConditionalGeneration(nn.Module):
input_ids == self.vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
-# only copy the sum of the image tokens GPU<->CPU
+# only copy the sum of the image and video tokens GPU<->CPU
image_count = (vision_tokens == self.image_token_id).sum().item()
+video_count = (vision_tokens == self.video_token_id).sum().item()
current_pos = 0
-for _ in range(image_count):
-# copy the value position of the next image token from GPU<->CPU
-next_image_pos = (
-(input_ids[current_pos:] == self.image_token_id)
+for _ in range(image_count + video_count):
+# copy the value position of the next image or video token from GPU<->CPU
+next_vision_pos = (
+(
+(input_ids[current_pos:] == self.image_token_id)
+| (input_ids[current_pos:] == self.video_token_id)
+)
.nonzero()[0]
.item()
)
# TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-time_steps, height, width = image_grid_thw[image_index].clone()
+is_video = (
+input_ids[current_pos + next_vision_pos] == self.video_token_id
+)
+grid_thw = (
+video_grid_thw[vision_index]
+if is_video
+else image_grid_thw[vision_index]
+)
+time_steps, height, width = grid_thw.clone()
height //= self.spatial_merge_size
width //= self.spatial_merge_size
# calculate the length of the text and image tokens
-text_length = next_image_pos
+text_length = next_vision_pos - current_pos
start_idx = (
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
)
@@ -460,7 +478,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
llm_pos_ids_list.append(text_pos_ids)
-# image position ids
+# vision position ids
t_indices = torch.arange(time_steps, device=d).repeat_interleave(
height * width
)
@@ -473,16 +491,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
height * time_steps
)
-image_pos_ids = (
+vision_pos_ids = (
torch.stack([t_indices, h_indices, w_indices])
+ text_length
+ start_idx
)
-llm_pos_ids_list.append(image_pos_ids)
-current_pos += next_image_pos + time_steps * height * width
-image_index += 1
+llm_pos_ids_list.append(vision_pos_ids)
+current_pos = next_vision_pos + time_steps * height * width
+vision_index += 1
+# Handle remaining text if any
if current_pos < batch_input_ids.size(1):
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
@@ -515,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
+video_pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
@@ -525,13 +545,27 @@ class Qwen2VLForConditionalGeneration(nn.Module):
):
inputs_embeds = self.embed_tokens(input_ids)
+if video_pixel_values is not None and len(video_pixel_values) > 0:
+vision_embeds = self.visual(
+video_pixel_values, grid_thw=video_grid_thw
+).squeeze(0)
+vision_token_mask = input_ids == self.video_token_id
+inputs_embeds[vision_token_mask] = vision_embeds
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
-if pixel_values is not None:
-image_embeds = self.visual(
-pixel_values, grid_thw=image_grid_thw
-).squeeze(0)
-inputs_embeds[input_ids == self.image_token_id] = image_embeds
+vision_embeds = self.visual(
+pixel_values,
+grid_thw=(
+torch.cat([image_grid_thw, video_grid_thw])
+if video_grid_thw is not None
+else image_grid_thw
+),
+).squeeze(0)
+# Apply embeddings to image tokens
+vision_token_mask = input_ids == self.image_token_id
+inputs_embeds[vision_token_mask] = vision_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,


@@ -148,7 +148,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
if image_inputs is not None:
assert len(image_indices) == image_inputs["pixel_values"].shape[0]
-return batch_tokenized_inputs, image_inputs
+video_inputs = None
+return batch_tokenized_inputs, image_inputs, video_inputs
@classmethod
def from_pb_processor(
@@ -160,8 +161,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
-batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-pb.requests, tokenizer, processor, config
+batch_tokenized_inputs, image_inputs, _video_inputs = (
+cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.


@@ -68,4 +68,6 @@ class PaliGemmaBatch(VlmCausalLMBatch):
image_inputs = new_image_inputs
else:
image_inputs = None
-return batch_tokenized_inputs, image_inputs
+video_inputs = None
+return batch_tokenized_inputs, image_inputs, video_inputs


@@ -2,6 +2,7 @@ import torch
from PIL import Image
from io import BytesIO
+import numpy as np
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@@ -17,6 +18,7 @@ from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+import math
tracer = trace.get_tracer(__name__)
@@ -76,6 +78,15 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def video_text_replacement(processor, video_input, config) -> str:
if config.model_type == "qwen2_vl":
num_pads = video_input.pixel_values.shape[0] // 4
padding = "<|video_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
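A worked example of the replacement above (assuming, as the `// 4` suggests, one `<|video_pad|>` token per 2x2 group of flattened vision patches; the patch count below is illustrative):

```python
def count_video_pads(num_patches: int) -> int:
    # video_text_replacement emits num_patches // 4 pad tokens between
    # <|vision_start|> and <|vision_end|>.
    return num_patches // 4

# e.g. a clip that produces 4040 flattened patches -> 1010 placeholder tokens
print(count_video_pads(4040))  # 1010
```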
def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2":
return text.replace(
@@ -138,29 +149,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def smart_nframes(
fps: int,
nframes: int,
min_frames: int,
max_frames: int,
total_frames: int,
video_fps: int | float,
) -> int:
if nframes:
nframes = round(nframes / 2) * 2
else:
min_frames = math.ceil(min_frames / 2) * 2
max_frames = math.floor(max_frames / 2) * 2
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round(nframes / 2) * 2
if not (2 <= nframes and nframes <= total_frames):
raise ValueError(
f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
)
return nframes
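A usage sketch for `smart_nframes` as defined above (it assumes that function is in scope; the numbers are illustrative): a 10-second clip decoded at 25 fps, sampled at a 2 fps target with no explicit `nframes` request, yields 20 frames.

```python
# total_frames / video_fps * fps = 250 / 25 * 2 = 20, already even and
# inside [min_frames, max_frames], so 20 is returned.
print(smart_nframes(fps=2, nframes=0, min_frames=4, max_frames=256,
                    total_frames=250, video_fps=25))  # 20
```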
class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
+video_pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
+video_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
+batch.video_pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
+batch.video_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
+batch.video_pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
+batch.video_grid_thw = None
return batch
@classmethod
@@ -171,6 +212,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
+videos = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
@@ -190,6 +232,30 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append(image)
else:
images.append([image])
elif chunk_type == "video":
if config.model_type == "qwen2_vl":
video_frame_buf = np.frombuffer(
chunk.video.data, dtype=np.uint8
)
num_bytes = len(video_frame_buf)
bytes_per_frame = num_bytes // chunk.video.frames
# iterate over with a stride the size of a frame
frames = []
for i in range(chunk.video.frames):
frame = video_frame_buf[
i * bytes_per_frame : (i + 1) * bytes_per_frame
]
frame = frame.reshape(
chunk.video.height, chunk.video.width, 3
)
frames.append(frame)
video_frame_buf = np.stack(frames)
frame_nchw_tensor = torch.from_numpy(video_frame_buf).permute(
0, 3, 1, 2
)
videos.append(frame_nchw_tensor)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
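A self-contained sketch of the reshaping done for a `video` chunk above: the proto carries `frames * height * width * 3` packed RGB bytes, which are split per frame and permuted to NCHW before being handed to the image processor (toy dimensions below):

```python
import numpy as np
import torch

def frames_to_nchw(data: bytes, frames: int, height: int, width: int) -> torch.Tensor:
    buf = np.frombuffer(data, dtype=np.uint8)
    stacked = buf.reshape(frames, height, width, 3)               # one RGB frame per block
    return torch.from_numpy(stacked.copy()).permute(0, 3, 1, 2)   # NHWC -> NCHW

t = frames_to_nchw(bytes(2 * 4 * 6 * 3), frames=2, height=4, width=6)
print(t.shape)  # torch.Size([2, 3, 4, 6])
```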
@@ -198,6 +264,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
else:
image_inputs = None
video_inputs = None
if videos:
try:
video_inputs = processor.image_processor(
videos,
return_tensors="pt",
)
except Exception as e:
print(f"Failed to process video: {e}")
pass
else:
video_inputs = None
batch_inputs = []
max_truncation = 0
image_id = 0
@@ -212,9 +291,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
processor, image_inputs, config, image_id
)
image_id += 1
+elif chunk_type == "video":
+full_text += video_text_replacement(processor, video_inputs, config)
full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
@@ -225,7 +305,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
-return batch_tokenized_inputs, image_inputs
+return batch_tokenized_inputs, image_inputs, video_inputs
@classmethod
def from_pb_processor(
@@ -237,10 +317,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
-batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if video_inputs is not None:
if "pixel_values" in video_inputs:
batch.video_pixel_values = video_inputs["pixel_values"].to(
device=device
)
if "image_grid_thw" in video_inputs:
batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.video_pixel_values = None
batch.video_grid_thw = None
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
@@ -257,6 +350,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
if "video_grid_thw" in image_inputs:
batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
@@ -367,7 +464,7 @@ class VlmCausalLM(FlashCausalLM):
if self.model.config.model_type == "qwen2_vl":
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
-input_ids, batch.image_grid_thw
+input_ids, batch.image_grid_thw, batch.video_grid_thw
)
batch.position_ids = position_ids
@@ -420,20 +517,26 @@ class VlmCausalLM(FlashCausalLM):
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
+video_pixel_values=batch.video_pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
+video_grid_thw=batch.video_grid_thw,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
+if batch.video_pixel_values is not None:
+batch.video_pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
+if batch.video_grid_thw is not None:
+batch.video_grid_thw = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph