Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00

Commit 33299d1ee9: Merge b27749eba7 into afb6c728d8
2  .github/workflows/autodocs.yaml  vendored
@@ -20,7 +20,7 @@ jobs:
       - name: Install Protocol Buffers compiler
         run: |
           sudo apt-get update
-          sudo apt-get install -y protobuf-compiler libprotobuf-dev
+          sudo apt-get install -y protobuf-compiler libprotobuf-dev clang libavcodec-dev libavfilter-dev libavdevice-dev libavformat-dev libavutil-dev pkg-config
 
       - name: Install Launcher
         id: install-launcher
4  .github/workflows/tests.yaml  vendored
@@ -43,7 +43,9 @@ jobs:
       - name: Install
        run: |
          sudo apt update
-          sudo apt install python3.11-dev -y
+          sudo apt install python3.11-dev python3.11-venv python3-pip clang libavcodec-dev libavfilter-dev libavdevice-dev libavformat-dev libavutil-dev pkg-config -y
+          export PKG_CONFIG_PATH=$PKG_CONFIG_PATH:/usr/lib/x86_64-linux-gnu/pkgconfig
          python -m pip install --upgrade pip
          make install-cpu
       - name: Run server tests
        run: |
89  Cargo.lock  generated
@@ -267,7 +267,7 @@ version = "0.23.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96"
 dependencies = [
- "bindgen",
+ "bindgen 0.69.5",
  "cc",
  "cmake",
  "dunce",
@@ -454,6 +454,24 @@ dependencies = [
  "which",
 ]
 
+[[package]]
+name = "bindgen"
+version = "0.70.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
+dependencies = [
+ "bitflags 2.6.0",
+ "cexpr",
+ "clang-sys",
+ "itertools 0.13.0",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash",
+ "shlex",
+ "syn 2.0.89",
+]
+
 [[package]]
 name = "bit-set"
 version = "0.5.3"
@@ -487,6 +505,15 @@ version = "2.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
 
+[[package]]
+name = "bitreader"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "886559b1e163d56c765bc3a985febb4eee8009f625244511d8ee3c432e08c066"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "bitstream-io"
 version = "2.6.0"
@@ -1194,6 +1221,15 @@ dependencies = [
  "zune-inflate",
 ]
 
+[[package]]
+name = "fallible_collections"
+version = "0.4.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd"
+dependencies = [
+ "hashbrown 0.13.2",
+]
+
 [[package]]
 name = "fancy-regex"
 version = "0.11.0"
@@ -1219,6 +1255,31 @@ dependencies = [
  "simd-adler32",
 ]
 
+[[package]]
+name = "ffmpeg-next"
+version = "7.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da02698288e0275e442a47fc12ca26d50daf0d48b15398ba5906f20ac2e2a9f9"
+dependencies = [
+ "bitflags 2.6.0",
+ "ffmpeg-sys-next",
+ "libc",
+]
+
+[[package]]
+name = "ffmpeg-sys-next"
+version = "7.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bc3234d0a4b2f7d083699d0860c6c9dd83713908771b60f94a96f8704adfe45"
+dependencies = [
+ "bindgen 0.70.1",
+ "cc",
+ "libc",
+ "num_cpus",
+ "pkg-config",
+ "vcpkg",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"
@@ -1512,6 +1573,15 @@ version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
 
+[[package]]
+name = "hashbrown"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
+dependencies = [
+ "ahash",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.14.5"
@@ -2471,6 +2541,20 @@ dependencies = [
  "syn 2.0.89",
 ]
 
+[[package]]
+name = "mp4parse"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570"
+dependencies = [
+ "bitreader",
+ "byteorder",
+ "fallible_collections",
+ "log",
+ "num-traits",
+ "static_assertions",
+]
+
 [[package]]
 name = "multimap"
 version = "0.10.0"
@@ -4425,6 +4509,7 @@ dependencies = [
  "base64 0.22.1",
  "clap 4.5.21",
  "csv",
+ "ffmpeg-next",
  "futures",
  "futures-util",
  "hf-hub",
@@ -4436,6 +4521,7 @@ dependencies = [
  "metrics-exporter-prometheus",
  "minijinja",
  "minijinja-contrib",
+ "mp4parse",
  "ngrok",
  "nohash-hasher",
  "once_cell",
@@ -4449,6 +4535,7 @@ dependencies = [
  "serde",
  "serde_json",
  "sysinfo",
+ "tempfile",
  "thiserror",
  "tokenizers",
  "tokio",
73  Dockerfile
@@ -20,6 +20,20 @@ FROM chef AS builder
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     python3.11-dev
+
+RUN apt-get update && apt-get install -y \
+    ffmpeg \
+    libavcodec-dev \
+    libavfilter-dev \
+    libavdevice-dev \
+    libavformat-dev \
+    libavutil-dev \
+    libswscale-dev \
+    pkg-config \
+    libclang-dev \
+    clang \
+    && rm -rf /var/lib/apt/lists/*
+
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -27,7 +41,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --profile release-opt --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --features video --recipe-path recipe.json
 
 ARG GIT_SHA
 ARG DOCKER_LABEL
@@ -40,7 +54,7 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt --frozen
+RUN cargo build --profile release-opt --frozen --features video
 
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -61,18 +75,18 @@ ARG TARGETPLATFORM
 ENV PATH /opt/conda/bin:$PATH
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    build-essential \
-    ca-certificates \
-    ccache \
-    curl \
-    git && \
-    rm -rf /var/lib/apt/lists/*
+        build-essential \
+        ca-certificates \
+        ccache \
+        curl \
+        git && \
+        rm -rf /var/lib/apt/lists/*
 
 # Install conda
 # translating Docker's TARGETPLATFORM into mamba arches
 RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
-    *) MAMBA_ARCH=x86_64 ;; \
+        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+        *) MAMBA_ARCH=x86_64 ;; \
     esac && \
     curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
 RUN chmod +x ~/mambaforge.sh && \
@@ -82,12 +96,15 @@ RUN chmod +x ~/mambaforge.sh && \
 # Install pytorch
 # On arm64 we exit with an error code
 RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") exit 1 ;; \
-    *) /opt/conda/bin/conda update -y conda && \
-    /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
+        "linux/arm64") exit 1 ;; \
+        *) /opt/conda/bin/conda update -y conda && \
+        /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" "openssl>=3.3.0" ;; \
     esac && \
     /opt/conda/bin/conda clean -ya
 
+RUN /opt/conda/bin/conda install -y pyOpenSSL
+
 
 # CUDA kernels builder image
 FROM pytorch-install AS kernel-builder
@@ -95,8 +112,8 @@ ARG MAX_JOBS=8
 ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    ninja-build cmake \
-    && rm -rf /var/lib/apt/lists/*
+        ninja-build cmake \
+        && rm -rf /var/lib/apt/lists/*
 
 # Build Flash Attention CUDA kernels
 FROM kernel-builder AS flash-att-builder
@@ -188,12 +205,15 @@ ENV HF_HOME=/data \
 WORKDIR /usr/src
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    libssl-dev \
-    ca-certificates \
-    make \
-    curl \
-    git \
-    && rm -rf /var/lib/apt/lists/*
+        libssl-dev \
+        ca-certificates \
+        make \
+        curl \
+        git \
+        && rm -rf /var/lib/apt/lists/*
 
+# Add ffmpeg libraries to the path
+ENV LD_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
+
 # Copy conda with PyTorch installed
 COPY --from=pytorch-install /opt/conda /opt/conda
@@ -239,6 +259,8 @@ RUN cd server && \
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
 # Required to find libpython within the rust binaries
 ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
+ENV LD_PRELOAD="/opt/conda/lib/libcrypto.so.3"
+
 # This is needed because exl2 tries to load flash-attn
 # And fails with our builds.
 ENV EXLLAMA_NO_FLASH_ATTN=1
@@ -247,9 +269,9 @@ ENV EXLLAMA_NO_FLASH_ATTN=1
 # The binaries change on every build given we burn the SHA into them
 # The deps change less often.
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    build-essential \
-    g++ \
-    && rm -rf /var/lib/apt/lists/*
+        build-essential \
+        g++ \
+        && rm -rf /var/lib/apt/lists/*
 
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -258,6 +280,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
+# Copy the ffmpeg libraries
+COPY --from=builder /usr/lib/x86_64-linux-gnu/* /usr/lib/x86_64-linux-gnu-copy/
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu-copy"
+
 # AWS Sagemaker compatible image
 FROM base AS sagemaker
@@ -9,7 +9,7 @@ use thiserror::Error;
 use tonic::transport;
 use tonic::Status;
 
-pub use v3::{Chunk, Image, Input, InputChunk};
+pub use v3::{Chunk, Image, Input, InputChunk, Video};
 
 #[async_trait]
 pub trait Health {
@@ -79,6 +79,20 @@ impl ChunksToString for Vec<InputChunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
+            Some(Chunk::Video(Video {
+                data,
+                mimetype,
+                width,
+                height: _,
+                frames: _,
+            })) => {
+                // TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!(
+                    r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
+                    width, mimetype, encoded, mimetype
+                ));
+            }
             // We don't create empty chunks, so this should be unreachable.
             None => unreachable!("Chunks should never be empty"),
         });
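
Note: a minimal Python sketch of the string this new Chunk::Video arm emits (the tag layout mirrors the Rust format string above; the sample bytes are made up):

    import base64

    def video_chunk_to_string(data: bytes, mimetype: str, width: int) -> str:
        # Base64-encode the raw frames and wrap them in an HTML <video> tag,
        # matching the format string in the Rust arm above.
        encoded = base64.b64encode(data).decode("ascii")
        return (
            f'<video width="{width}">'
            f'<source src="data:{mimetype};base64,{encoded}" type="{mimetype}">'
            "</video>"
        )

    print(video_chunk_to_string(b"\x00\x01\x02", "video/mp4", 360))
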
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -301,6 +301,7 @@ impl TensorRtLlmBackendV2 {
         1 => match request.inputs.first().expect("Single item-chunk") {
             Chunk::Text(_) => Ok(()),
             Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
+            Chunk::Video(_) => Err(ValidationError(UnsupportedModality("video"))),
         },
     }
 }
@@ -15,7 +15,7 @@ pub use grpc_client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters,
+    StoppingCriteriaParameters, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -439,6 +439,13 @@ impl State {
                     data: image.data,
                     mimetype: image.mimetype,
                 }),
+                Chunk::Video(video) => client::Chunk::Video(client::Video {
+                    data: video.data,
+                    mimetype: video.mimetype,
+                    width: video.width,
+                    height: video.height,
+                    frames: video.num_frames,
+                }),
             }),
         })
         .collect(),
@@ -1922,6 +1922,24 @@
             ]
           }
         }
       },
+      {
+        "type": "object",
+        "required": [
+          "video_url",
+          "type"
+        ],
+        "properties": {
+          "type": {
+            "type": "string",
+            "enum": [
+              "video_url"
+            ]
+          },
+          "video_url": {
+            "$ref": "#/components/schemas/Url"
+          }
+        }
+      }
     ],
     "discriminator": {
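
Note: a chat request exercising this new schema variant would carry a content chunk shaped like the sketch below (illustrative payload only; the URL is a placeholder):

    payload = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": [
                    # the new chunk type added by this schema change
                    {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                    {"type": "text", "text": "Describe this video."},
                ],
            }
        ],
        "max_tokens": 100,
    }
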
12  flake.nix
@@ -115,15 +115,17 @@
       buildInputs =
         [
           benchmark
-          launcher
-          router
-          server
+          cargo
+          client
+          clippy
+          ffmpeg
+          launcher
           openssl.dev
           pkg-config
-          cargo
+          router
           rustPlatform.bindgenHook
           rustfmt
-          clippy
+          server
         ]
         ++ (with python3.pkgs; [
           docker
@@ -0,0 +1,19 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant"
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1733450914,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": null
+}
75  integration-tests/models/test_flash_qwen2_vl_video.py  Normal file
@@ -0,0 +1,75 @@
+import pytest
+import json
+import requests
+
+
+@pytest.fixture(scope="module")
+def qwen2_vl_handle(launcher):
+    with launcher(
+        "Qwen/Qwen2-VL-7B-Instruct",
+        max_input_length=10_000,
+        max_batch_prefill_tokens=10_000,
+        max_total_tokens=10_001,
+        cuda_graphs=[0],
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def qwen2_vl(qwen2_vl_handle):
+    await qwen2_vl_handle.health(300)
+    return qwen2_vl_handle.client
+
+
+@pytest.mark.asyncio
+async def test_qwen2_vl_simpl(qwen2_vl, response_snapshot):
+    responses = requests.post(
+        f"{qwen2_vl.base_url}/v1/chat/completions",
+        headers=qwen2_vl.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "video_url",
+                            "video_url": {
+                                "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4"
+                            },
+                        },
+                        {
+                            "type": "text",
+                            "text": "Describe this video.",
+                        },
+                    ],
+                },
+            ],
+            "seed": 42,
+            "max_tokens": 100,
+            "stream": True,
+        },
+    )
+
+    # iterate over the response in chunks
+    count = 0
+    full_text = ""
+    last_response = None
+    for chunk in responses.iter_content(chunk_size=1024):
+        if chunk:
+            count += 1
+            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
+            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
+            for line in lines:
+                if line == "[DONE]":
+                    break
+                print("=", line)
+                try:
+                    response = json.loads(line)
+                    # print(response)
+                    last_response = response
+                    full_text += response["choices"][0]["delta"]["content"]
+                except json.JSONDecodeError:
+                    pass
+
+    assert last_response == response_snapshot
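
Note: iter_content(chunk_size=1024) can split an SSE line across two chunks, in which case the except branch above silently drops that delta. A more robust sketch (same framing assumed) buffers partial lines:

    import json

    def iter_sse_json(response):
        # Accumulate decoded bytes until a full newline-terminated line is available.
        buffer = ""
        for chunk in response.iter_content(chunk_size=1024):
            buffer += chunk.decode("utf-8")
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                line = line.removeprefix("data: ").strip()
                if not line or line == "[DONE]":
                    continue
                yield json.loads(line)
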
@@ -19,6 +19,26 @@ defaultCrateOverrides
   };
   rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
 
+  ffmpeg-sys-next = attrs: {
+    nativeBuildInputs = [
+      pkg-config
+    ];
+    buildInputs = [
+      rustPlatform.bindgenHook
+      ffmpeg
+    ];
+  };
+
+  ffmpeg-next = attrs: {
+    # Somehow the variables that are passed are mangled, so they are not
+    # correctly passed to the ffmpeg-next build script. Worth investigating
+    # more since it's probably a bug in crate2nix or buildRustCrate.
+    postPatch = ''
+      substituteInPlace build.rs \
+        --replace-fail "DEP_FFMPEG_" "DEP_FFMPEG_SYS_NEXT_"
+    '';
+  };
+
   grpc-metadata = attrs: {
     src = filter {
       root = ../backends/grpc-metadata;
@@ -5,9 +5,11 @@
   cmake,
   isort,
   ninja,
+  rustPlatform,
   which,
   cudaPackages,
   openssl,
+  ffmpeg,
   pkg-config,
   poetry,
   protobuf,
@@ -26,6 +28,7 @@
 mkShell {
   nativeBuildInputs =
     [
+      rustPlatform.bindgenHook
       black
       isort
       pkg-config
@@ -53,6 +56,7 @@ mkShell {
   buildInputs =
     [
       openssl.dev
+      ffmpeg
     ]
     ++ (with python3.pkgs; [
       venvShellHook
@@ -64,12 +64,31 @@ message Image {
     string mimetype = 2;
 }
 
+message Video {
+    /// Binary video data (array of RGB data)
+    bytes data = 1;
+
+    /// Video MIME type.
+    string mimetype = 2;
+
+    /// Video width
+    uint32 width = 3;
+
+    /// Video height
+    uint32 height = 4;
+
+    /// Total number of frames
+    uint32 frames = 5;
+}
+
 message InputChunk {
     oneof chunk {
         /// Plain text data
         string text = 1;
        /// Image data
         Image image = 2;
+        /// Video URLs
+        Video video = 3;
     }
 }
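
Note: the data field carries raw RGB24 frames packed back to back, so its length is fully determined by the other fields. A quick Python check of that invariant (the numbers are illustrative):

    def expected_video_bytes(width: int, height: int, frames: int) -> int:
        # RGB24 is 3 bytes per pixel; frames are concatenated without padding.
        return frames * height * width * 3

    # e.g. 8 sampled frames at the router's 360x420 Qwen2-VL target:
    assert expected_video_bytes(360, 420, 8) == 8 * 420 * 360 * 3
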
@@ -14,20 +14,23 @@ async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
 clap = { version = "4.4.5", features = ["derive", "env"] }
+ffmpeg-next = { version = "7.1.0", optional = true }
 futures = "0.3.28"
 hf-hub = { workspace = true }
 itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = { workspace = true }
 metrics-exporter-prometheus = { workspace = true }
+mp4parse = { version = "0.17.0", optional = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.13.0"
 outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
+opentelemetry-otlp = "0.13.0"
 rand = "0.8.5"
 reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"
 serde_json = "1.0.107"
+tempfile = { version = "3.10.1", optional = true }
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
@@ -74,3 +77,4 @@ default = ["ngrok"]
 ngrok = ["dep:ngrok"]
 google = []
 kserve = []
+video = ["ffmpeg-next", "mp4parse", "tempfile"]
@@ -1173,6 +1173,7 @@ pub struct Url {
 pub enum MessageChunk {
     Text { text: String },
     ImageUrl { image_url: Url },
+    VideoUrl { video_url: Url },
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -1229,6 +1230,9 @@ impl From<Message> for TextMessage {
             .map(|chunk| match chunk {
                 MessageChunk::Text { text } => text,
                 MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
+                MessageChunk::VideoUrl { video_url } => {
+                    format!("<video>({})", video_url.url)
+                }
             })
             .collect::<Vec<_>>()
             .join(""),
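
Note: a small Python sketch of the same flattening step (the markers mirror the format strings above; the router later re-parses them with regexes):

    def chunk_to_text(chunk: dict) -> str:
        # Each chunk collapses to a plain-text marker, as in From<Message> for TextMessage.
        if chunk["type"] == "text":
            return chunk["text"]
        if chunk["type"] == "image_url":
            return f"![]({chunk['image_url']['url']})"
        if chunk["type"] == "video_url":
            return f"<video>({chunk['video_url']['url']})"
        raise ValueError(f"unknown chunk type: {chunk['type']}")

    assert chunk_to_text(
        {"type": "video_url", "video_url": {"url": "https://example.com/a.mp4"}}
    ) == "<video>(https://example.com/a.mp4)"
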
@@ -22,6 +22,15 @@ use tokio::sync::oneshot;
 use tracing::{instrument, Span};
 use {once_cell::sync::Lazy, regex::Regex};
 
+#[cfg(feature = "video")]
+use ffmpeg_next::{
+    format::Pixel,
+    media::Type,
+    software::scaling::{context::Context, flag::Flags},
+};
+#[cfg(feature = "video")]
+use std::io::Write;
+
 static DEFAULT_GENERATION_LENGTH: u32 = 1024;
 
 /// Validation
@@ -536,6 +545,140 @@ fn format_to_mimetype(format: ImageFormat) -> String {
     .to_string()
 }
 
+#[cfg(not(feature = "video"))]
+pub fn fetch_video(
+    _input: &str,
+    _target_width: u32,
+    _target_height: u32,
+) -> Result<ProcessedVideo, ValidationError> {
+    Err(ValidationError::VideoNotSupported)
+}
+
+#[cfg(feature = "video")]
+pub fn fetch_video(
+    input: &str,
+    target_width: u32,
+    target_height: u32,
+) -> Result<ProcessedVideo, ValidationError> {
+    let (data, mimetype) =
+        if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
+            let url = &input["<video>(".len()..input.len() - 1];
+            let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
+            (data, "video/mp4".to_string())
+        } else if input.starts_with("<video>(data:") {
+            let content = &input["<video>(data:".len()..input.len() - 1];
+            let tokens: Vec<&str> = content.split(';').collect();
+            if tokens.len() != 2 {
+                return Err(ValidationError::InvalidVideoContent(content.to_string()));
+            }
+            let mimetype = tokens[0];
+            let content = tokens[1];
+            if !content.starts_with("base64,") {
+                return Err(ValidationError::InvalidVideoContent(content.to_string()));
+            }
+            let data = STANDARD.decode(&content["base64,".len()..])?;
+            (data, mimetype.to_string())
+        } else {
+            return Err(ValidationError::InvalidVideoContent(input.to_string()));
+        };
+
+    // init ffmpeg
+    ffmpeg_next::init().map_err(|_| ValidationError::FFmpegError)?;
+
+    // create temporary file for ffmpeg input
+    let mut temp_file = tempfile::NamedTempFile::new().map_err(ValidationError::IoError)?;
+    temp_file
+        .write_all(&data)
+        .map_err(ValidationError::IoError)?;
+
+    let mut ictx =
+        ffmpeg_next::format::input(&temp_file.path()).map_err(|_| ValidationError::FFmpegError)?;
+
+    let input = ictx
+        .streams()
+        .best(Type::Video)
+        .ok_or(ValidationError::FFmpegError)?;
+    let video_stream_index = input.index();
+
+    let context_decoder = ffmpeg_next::codec::context::Context::from_parameters(input.parameters())
+        .map_err(|_| ValidationError::FFmpegError)?;
+    let mut decoder = context_decoder
+        .decoder()
+        .video()
+        .map_err(|_| ValidationError::FFmpegError)?;
+
+    let width = target_width;
+    let height = target_height;
+
+    let mut scaler = Context::get(
+        decoder.format(),
+        decoder.width(), // original width
+        decoder.height(),
+        Pixel::RGB24,
+        width, // target width
+        height,
+        Flags::BILINEAR,
+    )
+    .map_err(|_| ValidationError::FFmpegError)?;
+
+    let mut frame_index = 0;
+    let mut captured_frame_index = 0;
+    let mut frames = vec![];
+
+    let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg_next::decoder::Video,
+                                                  raw_fps: f32|
+     -> Result<(), ffmpeg_next::Error> {
+        let mut decoded = ffmpeg_next::util::frame::video::Video::empty();
+        let fps = raw_fps.floor();
+        while decoder.receive_frame(&mut decoded).is_ok() {
+            let mut rgb_frame = ffmpeg_next::util::frame::video::Video::empty();
+            scaler.run(&decoded, &mut rgb_frame)?;
+            if frame_index as f32 % fps == 0.0 {
+                captured_frame_index += 1;
+                // Create new buffer without padding
+                let mut frame_data =
+                    Vec::with_capacity((rgb_frame.width() * rgb_frame.height() * 3) as usize);
+                let src_data = rgb_frame.data(0);
+                let row_size = rgb_frame.width() as usize * 3;
+
+                // Copy each row without padding
+                for y in 0..rgb_frame.height() as usize {
+                    let start = y * rgb_frame.stride(0);
+                    let end = start + row_size;
+                    frame_data.extend_from_slice(&src_data[start..end]);
+                }
+                frames.push(frame_data);
+            }
+            frame_index += 1;
+        }
+        Ok(())
+    };
+
+    for (stream, packet) in ictx.packets() {
+        // Floor the fps to get a whole number
+        let fps = (stream.rate().numerator() as f32 / stream.rate().denominator() as f32).floor();
+
+        if stream.index() == video_stream_index {
+            decoder
+                .send_packet(&packet)
+                .map_err(|_| ValidationError::FFmpegError)?;
+            receive_and_process_decoded_frames(&mut decoder, fps)
+                .map_err(|_| ValidationError::FFmpegError)?;
+        }
+    }
+    decoder
+        .send_eof()
+        .map_err(|_| ValidationError::FFmpegError)?;
+
+    Ok(ProcessedVideo {
+        mimetype,
+        height,
+        width,
+        frames,
+        sampled_frames: captured_frame_index,
+    })
+}
+
 fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     if input.starts_with("![](http://") || input.starts_with("![](https://") {
         let url = &input["![](".len()..input.len() - 1];
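
Note: fetch_video therefore accepts exactly two textual forms, <video>(https://...) and <video>(data:<mimetype>;base64,<payload>). A Python sketch of the same prefix dispatch (network fetch omitted):

    import base64

    def parse_video_input(text: str):
        # Mirrors the prefix checks in fetch_video above.
        if text.startswith(("<video>(http://", "<video>(https://")):
            url = text[len("<video>("):-1]
            return "url", url, "video/mp4"  # URL inputs are assumed to be mp4
        if text.startswith("<video>(data:"):
            content = text[len("<video>(data:"):-1]
            mimetype, _, payload = content.partition(";")
            if not payload.startswith("base64,"):
                raise ValueError(f"invalid video content: {content}")
            return "bytes", base64.b64decode(payload[len("base64,"):]), mimetype
        raise ValueError(f"invalid video content: {text}")
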
@@ -624,6 +767,26 @@ fn image_tokens(
     }
 }
 
+fn video_tokens(config: &Config, height: u32, width: u32, sampled_frames: f32) -> String {
+    use Config::*;
+
+    match config {
+        Qwen2Vl(_config) => {
+            let min_frames = 2_f32;
+            let max_frames = 256_f32;
+            // make sure the frames are within the range and are even
+            let nframes = (sampled_frames).max(min_frames).min(max_frames);
+            let nframes = (nframes / 2.0).round() as usize * 2;
+            let num_tokens = nframes * height as usize * width as usize / 1541;
+            format!(
+                "<|vision_start|>{:?}<|vision_end|>",
+                "<|video_pad|>".repeat(num_tokens)
+            )
+        }
+        _ => unimplemented!("Video tokens are not supported for this model configuration"),
+    }
+}
+
 fn image_tokens_fixup(config: &Config, text: String) -> String {
     match config {
         Config::Idefics2(_) => {
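
Note: worked example of the Qwen2Vl arm above. For 10 sampled frames at the router's 360x420 target, the frame count is clamped to [2, 256] and rounded to an even number, and the pad count is nframes * height * width / 1541 with integer division:

    def qwen2_vl_video_pad_count(sampled_frames: float, height: int, width: int) -> int:
        # Same arithmetic as video_tokens; Rust's f32::round rounds half away from zero.
        nframes = min(max(sampled_frames, 2.0), 256.0)
        nframes = int(round(nframes / 2.0)) * 2
        return nframes * height * width // 1541

    assert qwen2_vl_video_pad_count(10, 420, 360) == 981
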
@@ -645,6 +808,10 @@ fn prepare_input<T: TokenizerTrait>(
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
+    // Add video regex
+    static VIDEO_RE: Lazy<Regex> =
+        Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
+
     let (tokenizer_query, input_chunks) = match config {
         Some(
             config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
@@ -652,6 +819,53 @@ fn prepare_input<T: TokenizerTrait>(
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
+
+            // handle video content first
+            for chunk in VIDEO_RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let processed_video = match config {
+                    Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) => {
+                        let default_target_width = 224;
+                        let default_target_height = 224;
+                        fetch_video(
+                            &inputs[chunk_start..chunk_end],
+                            default_target_width,
+                            default_target_height,
+                        )?
+                    }
+                    Qwen2Vl(_) => {
+                        let target_width = 360;
+                        let target_height = 420;
+                        fetch_video(&inputs[chunk_start..chunk_end], target_width, target_height)?
+                    }
+                    _ => {
+                        unreachable!("Video tokens are not supported for this model configuration")
+                    }
+                };
+
+                input_chunks.push(Chunk::Video(Video {
+                    data: processed_video.frames.iter().flatten().cloned().collect(),
+                    mimetype: processed_video.mimetype.clone(),
+                    width: processed_video.width,
+                    height: processed_video.height,
+                    num_frames: processed_video.frames.len() as u32,
+                }));
+                let video_tokens = video_tokens(
+                    config,
+                    processed_video.height,
+                    processed_video.width,
+                    processed_video.sampled_frames as f32,
+                );
+                tokenizer_query.push_str(&video_tokens);
+                start = chunk_end;
+            }
+
+            // handle image content after video content
             for chunk in RE.find_iter(&inputs) {
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
@@ -660,7 +874,10 @@ fn prepare_input<T: TokenizerTrait>(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                input_chunks.push(Chunk::Image(Image { data, mimetype }));
+                input_chunks.push(Chunk::Image(Image {
+                    data,
+                    mimetype: mimetype.clone(),
+                }));
                 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
@@ -683,7 +900,6 @@ fn prepare_input<T: TokenizerTrait>(
 
     Ok((encoding, input_chunks))
 }
-
 type TokenizerRequest = (
     (String, bool, Option<usize>),
     oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
@@ -696,10 +912,28 @@ pub struct Image {
     pub mimetype: String,
 }
 
+pub struct ProcessedVideo {
+    mimetype: String,
+    height: u32,
+    width: u32,
+    frames: Vec<Vec<u8>>, // RGB frames
+    sampled_frames: usize,
+}
+
 #[derive(Debug, Clone, Eq, PartialEq)]
+pub struct Video {
+    pub data: Vec<u8>,
+    pub mimetype: String,
+    pub width: u32,
+    pub height: u32,
+    pub num_frames: u32,
+}
+
+#[derive(Debug, Clone, Eq, PartialEq)]
 pub enum Chunk {
     Text(String),
     Image(Image),
+    Video(Video),
 }
 
 /// Convert input chunks to a stringly-typed input for backwards
@@ -718,6 +952,20 @@ impl ChunksToString for Vec<Chunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
+            Chunk::Video(Video {
+                data,
+                mimetype,
+                width,
+                height: _,
+                num_frames: _,
+            }) => {
+                // TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!(
+                    r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
+                    width, mimetype, encoded, mimetype
+                ));
+            }
         });
         output
     }
@@ -846,6 +1094,18 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("invalid video content: {0}")]
+    InvalidVideoContent(String),
+    #[error("could not parse MP4 file")]
+    MP4Error,
+    #[error("no video stream found")]
+    NoVideoStream,
+    #[error("io error: {0}")]
+    IoError(#[from] std::io::Error),
+    #[error("ffmpeg error")]
+    FFmpegError,
+    #[error("video not supported")]
+    VideoNotSupported,
 }
 
 #[cfg(test)]
@@ -81,6 +81,8 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_pixel_values: Optional[torch.FloatTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
@@ -751,6 +751,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_pixel_values: Optional[torch.FloatTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:
@@ -181,6 +181,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_pixel_values: Optional[torch.FloatTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
@@ -411,6 +411,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         batch_input_ids: torch.Tensor,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if batch_input_ids.dim() == 1:
@@ -424,8 +425,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             device=batch_input_ids.device,
         )
         d = batch_input_ids.device
-        if image_grid_thw is not None:
-            image_index = 0
+
+        # Handle both image and video tokens
+        if image_grid_thw is not None or video_grid_thw is not None:
+            vision_index = 0
             llm_pos_ids_list = []
 
             for i, input_ids in enumerate(batch_input_ids):
@@ -433,24 +436,39 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     input_ids == self.vision_start_token_id
                 ).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
-                # only copy the sum of the image tokens GPU<->CPU
+
+                # only copy the sum of the image and video tokens GPU<->CPU
                 image_count = (vision_tokens == self.image_token_id).sum().item()
+                video_count = (vision_tokens == self.video_token_id).sum().item()
 
                 current_pos = 0
-                for _ in range(image_count):
-                    # copy the value position of the next image token from GPU<->CPU
-                    next_image_pos = (
-                        (input_ids[current_pos:] == self.image_token_id)
+                for _ in range(image_count + video_count):
+                    # copy the value position of the next image or video token from GPU<->CPU
+                    next_vision_pos = (
+                        (
+                            (input_ids[current_pos:] == self.image_token_id)
+                            | (input_ids[current_pos:] == self.video_token_id)
+                        )
                         .nonzero()[0]
                        .item()
                     )
 
-                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-                    time_steps, height, width = image_grid_thw[image_index].clone()
+                    is_video = (
+                        input_ids[current_pos + next_vision_pos] == self.video_token_id
+                    )
+                    grid_thw = (
+                        video_grid_thw[vision_index]
+                        if is_video
+                        else image_grid_thw[vision_index]
+                    )
+
+                    time_steps, height, width = grid_thw.clone()
                     height //= self.spatial_merge_size
                     width //= self.spatial_merge_size
 
                     # calculate the length of the text and image tokens
-                    text_length = next_image_pos
+                    text_length = next_vision_pos - current_pos
                     start_idx = (
                         llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                     )
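
Note: for intuition, these are Qwen2-VL's 3D (temporal/height/width) rope positions: text tokens share one scalar across all three axes, while each vision token gets its grid coordinate offset past the preceding text. A toy torch sketch (index construction assumed from the surrounding code) for 3 text tokens followed by one grid of t=1, h=2, w=2:

    import torch

    text_length, start_idx = 3, 0
    t, h, w = 1, 2, 2  # grid_thw after spatial merging

    # text: all three axes share the same scalar position
    text_pos = torch.arange(text_length).view(1, -1).expand(3, -1) + start_idx

    # vision: per-axis grid coordinates, shifted past the text span
    t_idx = torch.arange(t).repeat_interleave(h * w)
    h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
    w_idx = torch.arange(w).repeat(h * t)
    vision_pos = torch.stack([t_idx, h_idx, w_idx]) + text_length + start_idx

    print(text_pos)    # tensor([[0, 1, 2], [0, 1, 2], [0, 1, 2]])
    print(vision_pos)  # tensor([[3, 3, 3, 3], [3, 3, 4, 4], [3, 4, 3, 4]])
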
@@ -460,7 +478,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
                     llm_pos_ids_list.append(text_pos_ids)
 
-                    # image position ids
+                    # vision position ids
                     t_indices = torch.arange(time_steps, device=d).repeat_interleave(
                         height * width
                     )
@@ -473,16 +491,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                         height * time_steps
                     )
 
-                    image_pos_ids = (
+                    vision_pos_ids = (
                         torch.stack([t_indices, h_indices, w_indices])
                         + text_length
                         + start_idx
                     )
-                    llm_pos_ids_list.append(image_pos_ids)
+                    llm_pos_ids_list.append(vision_pos_ids)
 
-                    current_pos += next_image_pos + time_steps * height * width
-                    image_index += 1
+                    current_pos = next_vision_pos + time_steps * height * width
+                    vision_index += 1
 
                 # Handle remaining text if any
                 if current_pos < batch_input_ids.size(1):
                     st_idx = (
                         llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
@@ -515,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
+        video_pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
         pixel_attention_mask=None,
@@ -525,13 +545,27 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     ):
         inputs_embeds = self.embed_tokens(input_ids)
 
+        if video_pixel_values is not None and len(video_pixel_values) > 0:
+            vision_embeds = self.visual(
+                video_pixel_values, grid_thw=video_grid_thw
+            ).squeeze(0)
+            vision_token_mask = input_ids == self.video_token_id
+            inputs_embeds[vision_token_mask] = vision_embeds
+
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
-            if pixel_values is not None:
-                image_embeds = self.visual(
-                    pixel_values, grid_thw=image_grid_thw
-                ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+            vision_embeds = self.visual(
+                pixel_values,
+                grid_thw=(
+                    torch.cat([image_grid_thw, video_grid_thw])
+                    if video_grid_thw is not None
+                    else image_grid_thw
+                ),
+            ).squeeze(0)
+
+            # Apply embeddings to image tokens
+            vision_token_mask = input_ids == self.image_token_id
+            inputs_embeds[vision_token_mask] = vision_embeds
 
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
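
Note: the masked assignment above works because self.visual returns one embedding row per vision token, in the order the tokens appear in the sequence. A minimal torch sketch of that scatter:

    import torch

    IMAGE_TOKEN_ID = 9  # hypothetical id
    inputs_embeds = torch.zeros(6, 4)             # 6 tokens, hidden size 4
    input_ids = torch.tensor([1, 9, 9, 2, 9, 3])  # three image tokens
    vision_embeds = torch.ones(3, 4)              # one row per image token

    mask = input_ids == IMAGE_TOKEN_ID
    inputs_embeds[mask] = vision_embeds           # rows 1, 2 and 4 are replaced
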
@@ -148,7 +148,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         if image_inputs is not None:
             assert len(image_indices) == image_inputs["pixel_values"].shape[0]
 
-        return batch_tokenized_inputs, image_inputs
+        video_inputs = None
+        return batch_tokenized_inputs, image_inputs, video_inputs
 
     @classmethod
     def from_pb_processor(
@@ -160,8 +161,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
+        batch_tokenized_inputs, image_inputs, _video_inputs = (
+            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
@@ -68,4 +68,6 @@ class PaliGemmaBatch(VlmCausalLMBatch):
             image_inputs = new_image_inputs
         else:
             image_inputs = None
-        return batch_tokenized_inputs, image_inputs
+
+        video_inputs = None
+        return batch_tokenized_inputs, image_inputs, video_inputs
@@ -2,6 +2,7 @@ import torch
 from PIL import Image
 from io import BytesIO
 
+import numpy as np
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict
 
@@ -17,6 +18,7 @@ from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+import math
 
 tracer = trace.get_tracer(__name__)
@@ -76,6 +78,15 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
 
+def video_text_replacement(processor, video_input, config) -> str:
+    if config.model_type == "qwen2_vl":
+        num_pads = video_input.pixel_values.shape[0] // 4
+        padding = "<|video_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
 def image_text_replacement_fixup(config, text: str) -> str:
     if config.model_type == "idefics2":
         return text.replace(
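
Note: the // 4 presumably reflects Qwen2-VL's 2x2 spatial patch merging: pixel_values.shape[0] counts flattened vision patches, and four patches collapse into one placeholder token. A hypothetical example (the shape is made up):

    num_patches = 3920                   # video_input.pixel_values.shape[0]
    num_pads = num_patches // 4          # 2x2 merge -> 980 placeholder tokens
    prompt = "<|vision_start|>" + "<|video_pad|>" * num_pads + "<|vision_end|>"
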
@@ -138,29 +149,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
+# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
+def smart_nframes(
+    fps: int,
+    nframes: int,
+    min_frames: int,
+    max_frames: int,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    if nframes:
+        nframes = round(nframes / 2) * 2
+    else:
+        min_frames = math.ceil(min_frames / 2) * 2
+        max_frames = math.floor(max_frames / 2) * 2
+        nframes = total_frames / video_fps * fps
+        nframes = min(max(nframes, min_frames), max_frames)
+        nframes = round(nframes / 2) * 2
+    if not (2 <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
+        )
+    return nframes
+
+
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    video_pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    video_grid_thw: Optional[torch.Tensor]
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch
 
     @classmethod
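
Note: worked example of smart_nframes above. A 10 s clip decoded at 30 fps (total_frames=300) sampled at fps=2 with bounds [2, 256] and nframes unset yields 20 frames:

    import math

    fps, nframes, min_frames, max_frames = 2, 0, 2, 256
    total_frames, video_fps = 300, 30    # a 10 s clip decoded at 30 fps

    min_frames = math.ceil(min_frames / 2) * 2     # 2
    max_frames = math.floor(max_frames / 2) * 2    # 256
    sampled = total_frames / video_fps * fps       # 20.0
    sampled = min(max(sampled, min_frames), max_frames)
    sampled = round(sampled / 2) * 2               # keep the count even
    assert sampled == 20
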
@@ -171,6 +212,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         # can make the image splits the same size. And we need the final
         # sizes to insert correct number of image tokens.
         images = []
+        videos = []
         for r in requests:
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
@@ -190,6 +232,30 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         images.append(image)
                     else:
                         images.append([image])
+                elif chunk_type == "video":
+                    if config.model_type == "qwen2_vl":
+                        video_frame_buf = np.frombuffer(
+                            chunk.video.data, dtype=np.uint8
+                        )
+                        num_bytes = len(video_frame_buf)
+                        bytes_per_frame = num_bytes // chunk.video.frames
+
+                        # iterate over with a stride the size of a frame
+                        frames = []
+                        for i in range(chunk.video.frames):
+                            frame = video_frame_buf[
+                                i * bytes_per_frame : (i + 1) * bytes_per_frame
+                            ]
+                            frame = frame.reshape(
+                                chunk.video.height, chunk.video.width, 3
+                            )
+                            frames.append(frame)
+
+                        video_frame_buf = np.stack(frames)
+                        frame_nchw_tensor = torch.from_numpy(video_frame_buf).permute(
+                            0, 3, 1, 2
+                        )
+                        videos.append(frame_nchw_tensor)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
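
Note: since frames are packed back to back, the per-frame loop above is equivalent to one vectorized reshape whenever every frame has the same height*width*3 size (which the protocol guarantees). A sketch:

    import numpy as np
    import torch

    def unpack_video(data: bytes, frames: int, height: int, width: int) -> torch.Tensor:
        # bytes -> (frames, height, width, 3) uint8 -> NCHW tensor
        buf = np.frombuffer(data, dtype=np.uint8).reshape(frames, height, width, 3)
        return torch.from_numpy(buf.copy()).permute(0, 3, 1, 2)

    video = unpack_video(bytes(2 * 4 * 4 * 3), frames=2, height=4, width=4)
    assert video.shape == (2, 3, 4, 4)
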
@@ -198,6 +264,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None
 
+        video_inputs = None
+        if videos:
+            try:
+                video_inputs = processor.image_processor(
+                    videos,
+                    return_tensors="pt",
+                )
+            except Exception as e:
+                print(f"Failed to process video: {e}")
+                pass
+        else:
+            video_inputs = None
+
         batch_inputs = []
         max_truncation = 0
         image_id = 0
@@ -212,9 +291,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         processor, image_inputs, config, image_id
                     )
                     image_id += 1
+                elif chunk_type == "video":
+                    full_text += video_text_replacement(processor, video_inputs, config)
 
             full_text = image_text_replacement_fixup(config, full_text)
-
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -225,7 +305,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]
 
-        return batch_tokenized_inputs, image_inputs
+        return batch_tokenized_inputs, image_inputs, video_inputs
 
     @classmethod
     def from_pb_processor(
@@ -237,10 +317,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+        batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
             pb.requests, tokenizer, processor, config
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if video_inputs is not None:
+            if "pixel_values" in video_inputs:
+                batch.video_pixel_values = video_inputs["pixel_values"].to(
+                    device=device
+                )
+            if "image_grid_thw" in video_inputs:
+                batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
+        else:
+            batch.video_pixel_values = None
+            batch.video_grid_thw = None
+
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
             if "pixel_attention_mask" in image_inputs:
@@ -257,6 +350,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
             else:
                 batch.image_grid_thw = None
+            if "video_grid_thw" in image_inputs:
+                batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
         else:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
@@ -367,7 +464,7 @@ class VlmCausalLM(FlashCausalLM):
         if self.model.config.model_type == "qwen2_vl":
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids, batch.image_grid_thw
+                    input_ids, batch.image_grid_thw, batch.video_grid_thw
                 )
                 batch.position_ids = position_ids
@@ -420,20 +517,26 @@ class VlmCausalLM(FlashCausalLM):
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             pixel_values=batch.pixel_values,
+            video_pixel_values=batch.video_pixel_values,
             pixel_attention_mask=batch.pixel_attention_mask,
             image_sizes=batch.image_sizes,
             image_grid_thw=batch.image_grid_thw,
+            video_grid_thw=batch.video_grid_thw,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
+        if batch.video_pixel_values is not None:
+            batch.video_pixel_values = None
         if batch.pixel_attention_mask is not None:
             batch.pixel_attention_mask = None
         if batch.image_sizes is not None:
             batch.image_sizes = None
         if batch.image_grid_thw is not None:
             batch.image_grid_thw = None
+        if batch.video_grid_thw is not None:
+            batch.video_grid_thw = None
         return logits, speculative_logits
 
     # Copy inputs to the static inputs of the cuda graph