Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 12:24:53 +00:00)

Commit 33299d1ee9: Merge b27749eba7 into afb6c728d8
.github/workflows/autodocs.yaml (vendored), 2 lines changed

@@ -20,7 +20,7 @@ jobs:
       - name: Install Protocol Buffers compiler
         run: |
           sudo apt-get update
-          sudo apt-get install -y protobuf-compiler libprotobuf-dev
+          sudo apt-get install -y protobuf-compiler libprotobuf-dev clang libavcodec-dev libavfilter-dev libavdevice-dev libavformat-dev libavutil-dev pkg-config

       - name: Install Launcher
         id: install-launcher
.github/workflows/tests.yaml (vendored), 4 lines changed

@@ -43,7 +43,9 @@ jobs:
       - name: Install
         run: |
           sudo apt update
-          sudo apt install python3.11-dev -y
+          sudo apt install python3.11-dev python3.11-venv python3-pip clang libavcodec-dev libavfilter-dev libavdevice-dev libavformat-dev libavutil-dev pkg-config -y
+          export PKG_CONFIG_PATH=$PKG_CONFIG_PATH:/usr/lib/x86_64-linux-gnu/pkgconfig
+          python -m pip install --upgrade pip
           make install-cpu
       - name: Run server tests
         run: |
Cargo.lock (generated), 89 lines changed

@@ -267,7 +267,7 @@ version = "0.23.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96"
 dependencies = [
- "bindgen",
+ "bindgen 0.69.5",
  "cc",
  "cmake",
  "dunce",
@@ -454,6 +454,24 @@ dependencies = [
  "which",
 ]

+[[package]]
+name = "bindgen"
+version = "0.70.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
+dependencies = [
+ "bitflags 2.6.0",
+ "cexpr",
+ "clang-sys",
+ "itertools 0.13.0",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash",
+ "shlex",
+ "syn 2.0.89",
+]
+
 [[package]]
 name = "bit-set"
 version = "0.5.3"
@@ -487,6 +505,15 @@ version = "2.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"

+[[package]]
+name = "bitreader"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "886559b1e163d56c765bc3a985febb4eee8009f625244511d8ee3c432e08c066"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "bitstream-io"
 version = "2.6.0"
@@ -1194,6 +1221,15 @@ dependencies = [
  "zune-inflate",
 ]

+[[package]]
+name = "fallible_collections"
+version = "0.4.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd"
+dependencies = [
+ "hashbrown 0.13.2",
+]
+
 [[package]]
 name = "fancy-regex"
 version = "0.11.0"
@@ -1219,6 +1255,31 @@ dependencies = [
  "simd-adler32",
 ]

+[[package]]
+name = "ffmpeg-next"
+version = "7.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da02698288e0275e442a47fc12ca26d50daf0d48b15398ba5906f20ac2e2a9f9"
+dependencies = [
+ "bitflags 2.6.0",
+ "ffmpeg-sys-next",
+ "libc",
+]
+
+[[package]]
+name = "ffmpeg-sys-next"
+version = "7.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bc3234d0a4b2f7d083699d0860c6c9dd83713908771b60f94a96f8704adfe45"
+dependencies = [
+ "bindgen 0.70.1",
+ "cc",
+ "libc",
+ "num_cpus",
+ "pkg-config",
+ "vcpkg",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"
@@ -1512,6 +1573,15 @@ version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"

+[[package]]
+name = "hashbrown"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
+dependencies = [
+ "ahash",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.14.5"
@@ -2471,6 +2541,20 @@ dependencies = [
  "syn 2.0.89",
 ]

+[[package]]
+name = "mp4parse"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570"
+dependencies = [
+ "bitreader",
+ "byteorder",
+ "fallible_collections",
+ "log",
+ "num-traits",
+ "static_assertions",
+]
+
 [[package]]
 name = "multimap"
 version = "0.10.0"
@@ -4425,6 +4509,7 @@ dependencies = [
  "base64 0.22.1",
  "clap 4.5.21",
  "csv",
+ "ffmpeg-next",
  "futures",
  "futures-util",
  "hf-hub",
@@ -4436,6 +4521,7 @@ dependencies = [
  "metrics-exporter-prometheus",
  "minijinja",
  "minijinja-contrib",
+ "mp4parse",
  "ngrok",
  "nohash-hasher",
  "once_cell",
@@ -4449,6 +4535,7 @@ dependencies = [
  "serde",
  "serde_json",
  "sysinfo",
+ "tempfile",
  "thiserror",
  "tokenizers",
  "tokio",
Dockerfile, 73 lines changed

@@ -20,6 +20,20 @@ FROM chef AS builder

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     python3.11-dev
+
+RUN apt-get update && apt-get install -y \
+    ffmpeg \
+    libavcodec-dev \
+    libavfilter-dev \
+    libavdevice-dev \
+    libavformat-dev \
+    libavutil-dev \
+    libswscale-dev \
+    pkg-config \
+    libclang-dev \
+    clang \
+    && rm -rf /var/lib/apt/lists/*
+
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -27,7 +41,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP

 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --profile release-opt --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --features video --recipe-path recipe.json

 ARG GIT_SHA
 ARG DOCKER_LABEL
@@ -40,7 +54,7 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt --frozen
+RUN cargo build --profile release-opt --frozen --features video

 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -61,18 +75,18 @@ ARG TARGETPLATFORM
 ENV PATH /opt/conda/bin:$PATH

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
     ca-certificates \
     ccache \
     curl \
     git && \
     rm -rf /var/lib/apt/lists/*

 # Install conda
 # translating Docker's TARGETPLATFORM into mamba arches
 RUN case ${TARGETPLATFORM} in \
     "linux/arm64") MAMBA_ARCH=aarch64 ;; \
     *) MAMBA_ARCH=x86_64 ;; \
     esac && \
     curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
 RUN chmod +x ~/mambaforge.sh && \
@@ -82,12 +96,15 @@ RUN chmod +x ~/mambaforge.sh && \
 # Install pytorch
 # On arm64 we exit with an error code
 RUN case ${TARGETPLATFORM} in \
     "linux/arm64") exit 1 ;; \
     *) /opt/conda/bin/conda update -y conda && \
-    /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
+    /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" "openssl>=3.3.0" ;; \
     esac && \
     /opt/conda/bin/conda clean -ya

+RUN /opt/conda/bin/conda install -y pyOpenSSL
+
+
 # CUDA kernels builder image
 FROM pytorch-install AS kernel-builder

@@ -95,8 +112,8 @@ ARG MAX_JOBS=8
 ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     ninja-build cmake \
     && rm -rf /var/lib/apt/lists/*

 # Build Flash Attention CUDA kernels
 FROM kernel-builder AS flash-att-builder
@@ -188,12 +205,15 @@ ENV HF_HOME=/data \
 WORKDIR /usr/src

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     libssl-dev \
     ca-certificates \
     make \
     curl \
     git \
     && rm -rf /var/lib/apt/lists/*

+# Add ffmpeg libraries to the path
+ENV LD_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
+
 # Copy conda with PyTorch installed
 COPY --from=pytorch-install /opt/conda /opt/conda
@@ -239,6 +259,8 @@ RUN cd server && \
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
 # Required to find libpython within the rust binaries
 ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
+ENV LD_PRELOAD="/opt/conda/lib/libcrypto.so.3"
+
 # This is needed because exl2 tries to load flash-attn
 # And fails with our builds.
 ENV EXLLAMA_NO_FLASH_ATTN=1
@@ -247,9 +269,9 @@ ENV EXLLAMA_NO_FLASH_ATTN=1
 # The binaries change on every build given we burn the SHA into them
 # The deps change less often.
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
     g++ \
     && rm -rf /var/lib/apt/lists/*

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -258,6 +280,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 # Install launcher
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

+# Copy the ffmpeg libraries
+COPY --from=builder /usr/lib/x86_64-linux-gnu/* /usr/lib/x86_64-linux-gnu-copy/
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu-copy"
+
 # AWS Sagemaker compatible image
 FROM base AS sagemaker
@@ -9,7 +9,7 @@ use thiserror::Error;
 use tonic::transport;
 use tonic::Status;

-pub use v3::{Chunk, Image, Input, InputChunk};
+pub use v3::{Chunk, Image, Input, InputChunk, Video};

 #[async_trait]
 pub trait Health {
@@ -79,6 +79,20 @@ impl ChunksToString for Vec<InputChunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
+            Some(Chunk::Video(Video {
+                data,
+                mimetype,
+                width,
+                height: _,
+                frames: _,
+            })) => {
+                // TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!(
+                    r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
+                    width, mimetype, encoded, mimetype
+                ));
+            }
             // We don't create empty chunks, so this should be unreachable.
             None => unreachable!("Chunks should never be empty"),
         });
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -301,6 +301,7 @@ impl TensorRtLlmBackendV2 {
         1 => match request.inputs.first().expect("Single item-chunk") {
             Chunk::Text(_) => Ok(()),
             Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
+            Chunk::Video(_) => Err(ValidationError(UnsupportedModality("video"))),
         },
     }
 }
@@ -15,7 +15,7 @@ pub use grpc_client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters,
+    StoppingCriteriaParameters, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -439,6 +439,13 @@ impl State {
                         data: image.data,
                         mimetype: image.mimetype,
                     }),
+                    Chunk::Video(video) => client::Chunk::Video(client::Video {
+                        data: video.data,
+                        mimetype: video.mimetype,
+                        width: video.width,
+                        height: video.height,
+                        frames: video.num_frames,
+                    }),
                 }),
             })
             .collect(),
@@ -1922,6 +1922,24 @@
               ]
             }
           }
+        },
+        {
+          "type": "object",
+          "required": [
+            "video_url",
+            "type"
+          ],
+          "properties": {
+            "type": {
+              "type": "string",
+              "enum": [
+                "video_url"
+              ]
+            },
+            "video_url": {
+              "$ref": "#/components/schemas/Url"
+            }
+          }
         }
       ],
       "discriminator": {
flake.nix, 12 lines changed

@@ -115,15 +115,17 @@
           buildInputs =
             [
               benchmark
-              launcher
-              router
-              server
+              cargo
               client
+              clippy
+              ffmpeg
+              launcher
               openssl.dev
               pkg-config
-              cargo
+              router
+              rustPlatform.bindgenHook
               rustfmt
-              clippy
+              server
             ]
             ++ (with python3.pkgs; [
               docker
New file (streaming response snapshot for the Qwen2-VL video test), 19 lines added

@@ -0,0 +1,19 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant"
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1733450914,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": null
+}
integration-tests/models/test_flash_qwen2_vl_video.py (new file), 75 lines added

@@ -0,0 +1,75 @@
+import pytest
+import json
+import requests
+
+
+@pytest.fixture(scope="module")
+def qwen2_vl_handle(launcher):
+    with launcher(
+        "Qwen/Qwen2-VL-7B-Instruct",
+        max_input_length=10_000,
+        max_batch_prefill_tokens=10_000,
+        max_total_tokens=10_001,
+        cuda_graphs=[0],
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def qwen2_vl(qwen2_vl_handle):
+    await qwen2_vl_handle.health(300)
+    return qwen2_vl_handle.client
+
+
+@pytest.mark.asyncio
+async def test_qwen2_vl_simpl(qwen2_vl, response_snapshot):
+    responses = requests.post(
+        f"{qwen2_vl.base_url}/v1/chat/completions",
+        headers=qwen2_vl.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "video_url",
+                            "video_url": {
+                                "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4"
+                            },
+                        },
+                        {
+                            "type": "text",
+                            "text": "Describe this video.",
+                        },
+                    ],
+                },
+            ],
+            "seed": 42,
+            "max_tokens": 100,
+            "stream": True,
+        },
+    )
+
+    # iterate over the response in chunks
+    count = 0
+    full_text = ""
+    last_response = None
+    for chunk in responses.iter_content(chunk_size=1024):
+        if chunk:
+            count += 1
+            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
+            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
+            for line in lines:
+                if line == "[DONE]":
+                    break
+                print("=", line)
+                try:
+                    response = json.loads(line)
+                    # print(response)
+                    last_response = response
+                    full_text += response["choices"][0]["delta"]["content"]
+                except json.JSONDecodeError:
+                    pass
+
+    assert last_response == response_snapshot
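Note: the test above parses the SSE stream in fixed 1024-byte chunks, which quietly assumes every "data:" event fits inside a single chunk; an event split across a chunk boundary fails json.loads and is skipped. A more robust sketch (a hypothetical helper, not part of this PR) iterates over reassembled lines instead:

# Hypothetical, more robust SSE parsing; not part of this PR.
import json
import requests

def collect_stream(url: str, payload: dict) -> str:
    full_text = ""
    with requests.post(url, json=payload, stream=True) as resp:
        # iter_lines reassembles lines across chunk boundaries
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw or not raw.startswith("data: "):
                continue
            line = raw[len("data: "):]
            if line == "[DONE]":
                break
            event = json.loads(line)
            full_text += event["choices"][0]["delta"].get("content", "")
    return full_text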
@@ -19,6 +19,26 @@ defaultCrateOverrides
   };
   rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };

+  ffmpeg-sys-next = attrs: {
+    nativeBuildInputs = [
+      pkg-config
+    ];
+    buildInputs = [
+      rustPlatform.bindgenHook
+      ffmpeg
+    ];
+  };
+
+  ffmpeg-next = attrs: {
+    # Somehow the variables that are passed are mangled, so they are not
+    # correctly passed to the ffmpeg-next build script. Worth investigating
+    # more since it's probably a bug in crate2nix or buildRustCrate.
+    postPatch = ''
+      substituteInPlace build.rs \
+        --replace-fail "DEP_FFMPEG_" "DEP_FFMPEG_SYS_NEXT_"
+    '';
+  };
+
   grpc-metadata = attrs: {
     src = filter {
       root = ../backends/grpc-metadata;
@@ -5,9 +5,11 @@
   cmake,
   isort,
   ninja,
+  rustPlatform,
   which,
   cudaPackages,
   openssl,
+  ffmpeg,
   pkg-config,
   poetry,
   protobuf,
@@ -26,6 +28,7 @@
 mkShell {
   nativeBuildInputs =
     [
+      rustPlatform.bindgenHook
       black
       isort
       pkg-config
@@ -53,6 +56,7 @@ mkShell {
   buildInputs =
     [
       openssl.dev
+      ffmpeg
     ]
     ++ (with python3.pkgs; [
       venvShellHook
|
@ -64,12 +64,31 @@ message Image {
|
|||||||
string mimetype = 2;
|
string mimetype = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message Video {
|
||||||
|
/// Binary video data (array of RGB data)
|
||||||
|
bytes data = 1;
|
||||||
|
|
||||||
|
/// Video MIME type.
|
||||||
|
string mimetype = 2;
|
||||||
|
|
||||||
|
/// Video width
|
||||||
|
uint32 width = 3;
|
||||||
|
|
||||||
|
/// Video height
|
||||||
|
uint32 height = 4;
|
||||||
|
|
||||||
|
/// Total number of frames
|
||||||
|
uint32 frames = 5;
|
||||||
|
}
|
||||||
|
|
||||||
message InputChunk {
|
message InputChunk {
|
||||||
oneof chunk {
|
oneof chunk {
|
||||||
/// Plain text data
|
/// Plain text data
|
||||||
string text = 1;
|
string text = 1;
|
||||||
/// Image data
|
/// Image data
|
||||||
Image image = 2;
|
Image image = 2;
|
||||||
|
/// Video URLs
|
||||||
|
Video video = 3;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
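Note: for orientation, this is how a client might populate the new Video message; the data field carries the decoded, flattened RGB frames rather than the encoded container. A sketch, assuming the protoc-generated Python stubs are importable as generate_pb2 (the module name is an assumption):

# Sketch only; assumes protoc-generated stubs named generate_pb2.
import generate_pb2

def make_video_chunk(frames: list[bytes], width: int, height: int) -> generate_pb2.InputChunk:
    video = generate_pb2.Video(
        data=b"".join(frames),   # raw RGB24 frames, concatenated
        mimetype="video/mp4",    # MIME type of the original container
        width=width,
        height=height,
        frames=len(frames),
    )
    return generate_pb2.InputChunk(video=video)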
@@ -14,20 +14,23 @@ async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
 clap = { version = "4.4.5", features = ["derive", "env"] }
+ffmpeg-next = { version = "7.1.0", optional = true }
 futures = "0.3.28"
 hf-hub = { workspace = true }
 itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = { workspace = true }
 metrics-exporter-prometheus = { workspace = true }
+mp4parse = { version = "0.17.0", optional = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.13.0"
 outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
+opentelemetry-otlp = "0.13.0"
 rand = "0.8.5"
 reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"
 serde_json = "1.0.107"
+tempfile = { version = "3.10.1", optional = true }
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
@@ -74,3 +77,4 @@ default = ["ngrok"]
 ngrok = ["dep:ngrok"]
 google = []
 kserve = []
+video = ["ffmpeg-next", "mp4parse", "tempfile"]
@@ -1173,6 +1173,7 @@ pub struct Url {
 pub enum MessageChunk {
     Text { text: String },
     ImageUrl { image_url: Url },
+    VideoUrl { video_url: Url },
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -1229,6 +1230,9 @@ impl From<Message> for TextMessage {
             .map(|chunk| match chunk {
                 MessageChunk::Text { text } => text,
                 MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
+                MessageChunk::VideoUrl { video_url } => {
+                    format!("<video>({})", video_url.url)
+                }
             })
             .collect::<Vec<_>>()
             .join(""),
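Note: the router flattens multimodal chat chunks into a single prompt string; a video_url chunk becomes the <video>(...) marker that the validation regex further below scans for. A Python illustration of the same mapping (illustrative only; the real logic is the From<Message> impl above):

# Illustrative re-statement of the chunk-to-text mapping above.
def flatten_chunks(chunks: list[dict]) -> str:
    parts = []
    for chunk in chunks:
        if chunk["type"] == "text":
            parts.append(chunk["text"])
        elif chunk["type"] == "image_url":
            parts.append(f"![]({chunk['image_url']['url']})")
        elif chunk["type"] == "video_url":
            parts.append(f"<video>({chunk['video_url']['url']})")
    return "".join(parts)

assert flatten_chunks(
    [{"type": "video_url", "video_url": {"url": "https://x/v.mp4"}},
     {"type": "text", "text": "Describe this video."}]
) == "<video>(https://x/v.mp4)Describe this video."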
|
@ -22,6 +22,15 @@ use tokio::sync::oneshot;
|
|||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
use {once_cell::sync::Lazy, regex::Regex};
|
use {once_cell::sync::Lazy, regex::Regex};
|
||||||
|
|
||||||
|
#[cfg(feature = "video")]
|
||||||
|
use ffmpeg_next::{
|
||||||
|
format::Pixel,
|
||||||
|
media::Type,
|
||||||
|
software::scaling::{context::Context, flag::Flags},
|
||||||
|
};
|
||||||
|
#[cfg(feature = "video")]
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
|
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
|
||||||
|
|
||||||
/// Validation
|
/// Validation
|
||||||
@ -536,6 +545,140 @@ fn format_to_mimetype(format: ImageFormat) -> String {
|
|||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "video"))]
|
||||||
|
pub fn fetch_video(
|
||||||
|
_input: &str,
|
||||||
|
_target_width: u32,
|
||||||
|
_target_height: u32,
|
||||||
|
) -> Result<ProcessedVideo, ValidationError> {
|
||||||
|
Err(ValidationError::VideoNotSupported)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "video")]
|
||||||
|
pub fn fetch_video(
|
||||||
|
input: &str,
|
||||||
|
target_width: u32,
|
||||||
|
target_height: u32,
|
||||||
|
) -> Result<ProcessedVideo, ValidationError> {
|
||||||
|
let (data, mimetype) =
|
||||||
|
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
|
||||||
|
let url = &input["<video>(".len()..input.len() - 1];
|
||||||
|
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
|
||||||
|
(data, "video/mp4".to_string())
|
||||||
|
} else if input.starts_with("<video>(data:") {
|
||||||
|
let content = &input["<video>(data:".len()..input.len() - 1];
|
||||||
|
let tokens: Vec<&str> = content.split(';').collect();
|
||||||
|
if tokens.len() != 2 {
|
||||||
|
return Err(ValidationError::InvalidVideoContent(content.to_string()));
|
||||||
|
}
|
||||||
|
let mimetype = tokens[0];
|
||||||
|
let content = tokens[1];
|
||||||
|
if !content.starts_with("base64,") {
|
||||||
|
return Err(ValidationError::InvalidVideoContent(content.to_string()));
|
||||||
|
}
|
||||||
|
let data = STANDARD.decode(&content["base64,".len()..])?;
|
||||||
|
(data, mimetype.to_string())
|
||||||
|
} else {
|
||||||
|
return Err(ValidationError::InvalidVideoContent(input.to_string()));
|
||||||
|
};
|
||||||
|
|
||||||
|
// init ffmpeg
|
||||||
|
ffmpeg_next::init().map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
|
// create temporary file for ffmpeg input
|
||||||
|
let mut temp_file = tempfile::NamedTempFile::new().map_err(ValidationError::IoError)?;
|
||||||
|
temp_file
|
||||||
|
.write_all(&data)
|
||||||
|
.map_err(ValidationError::IoError)?;
|
||||||
|
|
||||||
|
let mut ictx =
|
||||||
|
ffmpeg_next::format::input(&temp_file.path()).map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
|
let input = ictx
|
||||||
|
.streams()
|
||||||
|
.best(Type::Video)
|
||||||
|
.ok_or(ValidationError::FFmpegError)?;
|
||||||
|
let video_stream_index = input.index();
|
||||||
|
|
||||||
|
let context_decoder = ffmpeg_next::codec::context::Context::from_parameters(input.parameters())
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
let mut decoder = context_decoder
|
||||||
|
.decoder()
|
||||||
|
.video()
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
|
let width = target_width;
|
||||||
|
let height = target_height;
|
||||||
|
|
||||||
|
let mut scaler = Context::get(
|
||||||
|
decoder.format(),
|
||||||
|
decoder.width(), // original width
|
||||||
|
decoder.height(),
|
||||||
|
Pixel::RGB24,
|
||||||
|
width, // target width
|
||||||
|
height,
|
||||||
|
Flags::BILINEAR,
|
||||||
|
)
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
|
let mut frame_index = 0;
|
||||||
|
let mut captured_frame_index = 0;
|
||||||
|
let mut frames = vec![];
|
||||||
|
|
||||||
|
let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg_next::decoder::Video,
|
||||||
|
raw_fps: f32|
|
||||||
|
-> Result<(), ffmpeg_next::Error> {
|
||||||
|
let mut decoded = ffmpeg_next::util::frame::video::Video::empty();
|
||||||
|
let fps = raw_fps.floor();
|
||||||
|
while decoder.receive_frame(&mut decoded).is_ok() {
|
||||||
|
let mut rgb_frame = ffmpeg_next::util::frame::video::Video::empty();
|
||||||
|
scaler.run(&decoded, &mut rgb_frame)?;
|
||||||
|
if frame_index as f32 % fps == 0.0 {
|
||||||
|
captured_frame_index += 1;
|
||||||
|
// Create new buffer without padding
|
||||||
|
let mut frame_data =
|
||||||
|
Vec::with_capacity((rgb_frame.width() * rgb_frame.height() * 3) as usize);
|
||||||
|
let src_data = rgb_frame.data(0);
|
||||||
|
let row_size = rgb_frame.width() as usize * 3;
|
||||||
|
|
||||||
|
// Copy each row without padding
|
||||||
|
for y in 0..rgb_frame.height() as usize {
|
||||||
|
let start = y * rgb_frame.stride(0);
|
||||||
|
let end = start + row_size;
|
||||||
|
frame_data.extend_from_slice(&src_data[start..end]);
|
||||||
|
}
|
||||||
|
frames.push(frame_data);
|
||||||
|
}
|
||||||
|
frame_index += 1;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
for (stream, packet) in ictx.packets() {
|
||||||
|
// Floor the fps to get a whole number
|
||||||
|
let fps = (stream.rate().numerator() as f32 / stream.rate().denominator() as f32).floor();
|
||||||
|
|
||||||
|
if stream.index() == video_stream_index {
|
||||||
|
decoder
|
||||||
|
.send_packet(&packet)
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
receive_and_process_decoded_frames(&mut decoder, fps)
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
decoder
|
||||||
|
.send_eof()
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
|
Ok(ProcessedVideo {
|
||||||
|
mimetype,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
frames,
|
||||||
|
sampled_frames: captured_frame_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
||||||
if input.starts_with(" || input.starts_with(" {
|
if input.starts_with(" || input.starts_with(" {
|
||||||
let url = &input["..input.len() - 1];
|
let url = &input["..input.len() - 1];
|
||||||
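Note: fetch_video keeps roughly one frame per second (frames where frame_index % floor(fps) == 0) and copies each scaled RGB24 frame row by row to drop the stride padding ffmpeg leaves at the end of each row. The same row-copy expressed in Python, assuming a raw buffer with a known stride (a sketch, not TGI code):

# Sketch of the stride-stripping copy done per frame above.
def strip_stride(buf: bytes, width: int, height: int, stride: int) -> bytes:
    row_size = width * 3          # RGB24: 3 bytes per pixel
    out = bytearray()
    for y in range(height):
        start = y * stride        # stride >= row_size due to alignment padding
        out += buf[start:start + row_size]
    return bytes(out)

# 2x2 frame stored with an 8-byte stride (2 padding bytes per row)
frame = bytes(range(8)) + bytes(range(8))
assert strip_stride(frame, width=2, height=2, stride=8) == frame[0:6] + frame[8:14]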
@@ -624,6 +767,26 @@ fn image_tokens(
     }
 }

+fn video_tokens(config: &Config, height: u32, width: u32, sampled_frames: f32) -> String {
+    use Config::*;
+
+    match config {
+        Qwen2Vl(_config) => {
+            let min_frames = 2_f32;
+            let max_frames = 256_f32;
+            // make sure the frames are within the range and are even
+            let nframes = (sampled_frames).max(min_frames).min(max_frames);
+            let nframes = (nframes / 2.0).round() as usize * 2;
+            let num_tokens = nframes * height as usize * width as usize / 1541;
+            format!(
+                "<|vision_start|>{:?}<|vision_end|>",
+                "<|video_pad|>".repeat(num_tokens)
+            )
+        }
+        _ => unimplemented!("Video tokens are not supported for this model configuration"),
+    }
+}
+
 fn image_tokens_fixup(config: &Config, text: String) -> String {
     match config {
         Config::Idefics2(_) => {
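Note: a worked example of the Qwen2-VL token budget above: the sampled frame count is clamped to [2, 256] and rounded to an even number, and each frame contributes height*width/1541 tokens (1541 is the divisor used in the diff). For the 360x420 target used later in prepare_input, 10 sampled frames yield 10 * 420 * 360 / 1541 = 981 pad tokens. Also note that the format! above uses {:?}, so the repeated pad string is Debug-formatted with surrounding quotes.

# Re-statement of the video_tokens arithmetic for Qwen2-VL.
def qwen2_vl_video_tokens(sampled_frames: float, height: int, width: int) -> int:
    nframes = min(max(sampled_frames, 2.0), 256.0)   # clamp to [2, 256]
    nframes = round(nframes / 2.0) * 2               # force an even frame count
    return nframes * height * width // 1541          # divisor from the diff above

assert qwen2_vl_video_tokens(10, height=420, width=360) == 981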
@@ -645,6 +808,10 @@ fn prepare_input<T: TokenizerTrait>(
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
+    // Add video regex
+    static VIDEO_RE: Lazy<Regex> =
+        Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
+
     let (tokenizer_query, input_chunks) = match config {
         Some(
             config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
@@ -652,6 +819,53 @@ fn prepare_input<T: TokenizerTrait>(
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
+
+            // handle video content first
+            for chunk in VIDEO_RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let processed_video = match config {
+                    Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) => {
+                        let default_target_width = 224;
+                        let default_target_height = 224;
+                        fetch_video(
+                            &inputs[chunk_start..chunk_end],
+                            default_target_width,
+                            default_target_height,
+                        )?
+                    }
+                    Qwen2Vl(_) => {
+                        let target_width = 360;
+                        let target_height = 420;
+                        fetch_video(&inputs[chunk_start..chunk_end], target_width, target_height)?
+                    }
+                    _ => {
+                        unreachable!("Video tokens are not supported for this model configuration")
+                    }
+                };
+
+                input_chunks.push(Chunk::Video(Video {
+                    data: processed_video.frames.iter().flatten().cloned().collect(),
+                    mimetype: processed_video.mimetype.clone(),
+                    width: processed_video.width,
+                    height: processed_video.height,
+                    num_frames: processed_video.frames.len() as u32,
+                }));
+                let video_tokens = video_tokens(
+                    config,
+                    processed_video.height,
+                    processed_video.width,
+                    processed_video.sampled_frames as f32,
+                );
+                tokenizer_query.push_str(&video_tokens);
+                start = chunk_end;
+            }
+
+            // handle image content after video content
             for chunk in RE.find_iter(&inputs) {
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
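Note: the VIDEO_RE pass above splits the prompt around each <video>(https://...) marker, emitting the preceding text as a Text chunk and the fetched frames as a Video chunk. The same scan in Python (illustrative; the trailing-text handling shown here is done later in the Rust code):

# Illustrative version of the marker scan above.
import re

VIDEO_RE = re.compile(r"<video>\((https?://[^\)]+)\)")

def split_prompt(inputs: str) -> list[tuple[str, str]]:
    chunks, start = [], 0
    for m in VIDEO_RE.finditer(inputs):
        if m.start() != start:                     # text before the marker
            chunks.append(("text", inputs[start:m.start()]))
        chunks.append(("video", m.group(1)))       # the captured URL
        start = m.end()
    if start != len(inputs):
        chunks.append(("text", inputs[start:]))
    return chunks

assert split_prompt("<video>(https://x/v.mp4)Describe this video.") == [
    ("video", "https://x/v.mp4"),
    ("text", "Describe this video."),
]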
@@ -660,7 +874,10 @@ fn prepare_input<T: TokenizerTrait>(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                input_chunks.push(Chunk::Image(Image { data, mimetype }));
+                input_chunks.push(Chunk::Image(Image {
+                    data,
+                    mimetype: mimetype.clone(),
+                }));
                 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
@@ -683,7 +900,6 @@ fn prepare_input<T: TokenizerTrait>(
-
     Ok((encoding, input_chunks))
 }

 type TokenizerRequest = (
     (String, bool, Option<usize>),
     oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
@@ -696,10 +912,28 @@ pub struct Image {
     pub mimetype: String,
 }

+pub struct ProcessedVideo {
+    mimetype: String,
+    height: u32,
+    width: u32,
+    frames: Vec<Vec<u8>>, // RGB frames
+    sampled_frames: usize,
+}
+
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub struct Video {
+    pub data: Vec<u8>,
+    pub mimetype: String,
+    pub width: u32,
+    pub height: u32,
+    pub num_frames: u32,
+}
+
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum Chunk {
     Text(String),
     Image(Image),
+    Video(Video),
 }

 /// Convert input chunks to a stringly-typed input for backwards
@@ -718,6 +952,20 @@ impl ChunksToString for Vec<Chunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
+            Chunk::Video(Video {
+                data,
+                mimetype,
+                width,
+                height: _,
+                num_frames: _,
+            }) => {
+                // TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!(
+                    r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
+                    width, mimetype, encoded, mimetype
+                ));
+            }
         });
         output
     }
@@ -846,6 +1094,18 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("invalid video content: {0}")]
+    InvalidVideoContent(String),
+    #[error("could not parse MP4 file")]
+    MP4Error,
+    #[error("no video stream found")]
+    NoVideoStream,
+    #[error("io error: {0}")]
+    IoError(#[from] std::io::Error),
+    #[error("ffmpeg error")]
+    FFmpegError,
+    #[error("video not supported")]
+    VideoNotSupported,
 }

 #[cfg(test)]
@@ -81,6 +81,8 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_pixel_values: Optional[torch.FloatTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
@@ -751,6 +751,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_pixel_values: Optional[torch.FloatTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:
@@ -181,6 +181,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_pixel_values: Optional[torch.FloatTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
@@ -411,6 +411,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         batch_input_ids: torch.Tensor,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if batch_input_ids.dim() == 1:
@@ -424,8 +425,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             device=batch_input_ids.device,
         )
         d = batch_input_ids.device
-        if image_grid_thw is not None:
-            image_index = 0
+
+        # Handle both image and video tokens
+        if image_grid_thw is not None or video_grid_thw is not None:
+            vision_index = 0
             llm_pos_ids_list = []

             for i, input_ids in enumerate(batch_input_ids):
@@ -433,24 +436,39 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     input_ids == self.vision_start_token_id
                 ).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
-                # only copy the sum of the image tokens GPU<->CPU
+
+                # only copy the sum of the image and video tokens GPU<->CPU
                 image_count = (vision_tokens == self.image_token_id).sum().item()
+                video_count = (vision_tokens == self.video_token_id).sum().item()
+
                 current_pos = 0
-                for _ in range(image_count):
-                    # copy the value position of the next image token from GPU<->CPU
-                    next_image_pos = (
-                        (input_ids[current_pos:] == self.image_token_id)
+                for _ in range(image_count + video_count):
+                    # copy the value position of the next image or video token from GPU<->CPU
+                    next_vision_pos = (
+                        (
+                            (input_ids[current_pos:] == self.image_token_id)
+                            | (input_ids[current_pos:] == self.video_token_id)
+                        )
                         .nonzero()[0]
                         .item()
                     )
+
                     # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-                    time_steps, height, width = image_grid_thw[image_index].clone()
+                    is_video = (
+                        input_ids[current_pos + next_vision_pos] == self.video_token_id
+                    )
+                    grid_thw = (
+                        video_grid_thw[vision_index]
+                        if is_video
+                        else image_grid_thw[vision_index]
+                    )
+
+                    time_steps, height, width = grid_thw.clone()
                     height //= self.spatial_merge_size
                     width //= self.spatial_merge_size

                     # calculate the length of the text and image tokens
-                    text_length = next_image_pos
+                    text_length = next_vision_pos - current_pos
                     start_idx = (
                         llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                     )
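Note: the t/h/w index construction in this function is Qwen2-VL's multimodal rotary scheme: every vision token gets a 3D position (time, row, column), offset by the running text position. A toy reproduction for one input of time_steps=1, height=2, width=2 after the spatial merge (a sketch; the h/w index formulas are our reconstruction, since the diff only shows fragments of them):

# Toy reproduction of the t/h/w position-id layout above.
import torch

time_steps, height, width = 1, 2, 2
start_idx, text_length = 0, 3      # e.g. three text tokens precede the vision tokens

t = torch.arange(time_steps).repeat_interleave(height * width)
h = torch.arange(height).repeat_interleave(width).repeat(time_steps)
w = torch.arange(width).repeat(height * time_steps)

vision_pos_ids = torch.stack([t, h, w]) + text_length + start_idx
# tensor([[3, 3, 3, 3],
#         [3, 3, 4, 4],
#         [3, 4, 3, 4]])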
@@ -460,7 +478,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
                     llm_pos_ids_list.append(text_pos_ids)

-                    # image position ids
+                    # vision position ids
                     t_indices = torch.arange(time_steps, device=d).repeat_interleave(
                         height * width
                     )
@@ -473,16 +491,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                         height * time_steps
                     )

-                    image_pos_ids = (
+                    vision_pos_ids = (
                         torch.stack([t_indices, h_indices, w_indices])
                         + text_length
                         + start_idx
                     )
-                    llm_pos_ids_list.append(image_pos_ids)
+                    llm_pos_ids_list.append(vision_pos_ids)

-                    current_pos += next_image_pos + time_steps * height * width
-                    image_index += 1
+                    current_pos = next_vision_pos + time_steps * height * width
+                    vision_index += 1
+
+            # Handle remaining text if any
             if current_pos < batch_input_ids.size(1):
                 st_idx = (
                     llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
@@ -515,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
+        video_pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
         pixel_attention_mask=None,
@@ -525,13 +545,27 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     ):
         inputs_embeds = self.embed_tokens(input_ids)

+        if video_pixel_values is not None and len(video_pixel_values) > 0:
+            vision_embeds = self.visual(
+                video_pixel_values, grid_thw=video_grid_thw
+            ).squeeze(0)
+            vision_token_mask = input_ids == self.video_token_id
+            inputs_embeds[vision_token_mask] = vision_embeds
+
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
-            if pixel_values is not None:
-                image_embeds = self.visual(
-                    pixel_values, grid_thw=image_grid_thw
-                ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+            vision_embeds = self.visual(
+                pixel_values,
+                grid_thw=(
+                    torch.cat([image_grid_thw, video_grid_thw])
+                    if video_grid_thw is not None
+                    else image_grid_thw
+                ),
+            ).squeeze(0)
+
+            # Apply embeddings to image tokens
+            vision_token_mask = input_ids == self.image_token_id
+            inputs_embeds[vision_token_mask] = vision_embeds

         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
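Note: the forward pass above swaps vision embeddings into the token stream with a boolean mask assignment: positions holding the video (or image) placeholder token receive rows of the visual encoder output, in order. A minimal demonstration of the mechanism (sketch with toy tensors):

# Minimal demo of masked assignment of vision embeddings into inputs_embeds.
import torch

hidden, video_token_id = 4, 99
input_ids = torch.tensor([1, 99, 99, 2])          # two video placeholder tokens
inputs_embeds = torch.zeros(4, hidden)
vision_embeds = torch.ones(2, hidden)             # one row per placeholder

mask = input_ids == video_token_id                # [False, True, True, False]
inputs_embeds[mask] = vision_embeds               # rows 1 and 2 replaced in order

assert inputs_embeds[1].sum() == hidden and inputs_embeds[0].sum() == 0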
@ -148,7 +148,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||||||
if image_inputs is not None:
|
if image_inputs is not None:
|
||||||
assert len(image_indices) == image_inputs["pixel_values"].shape[0]
|
assert len(image_indices) == image_inputs["pixel_values"].shape[0]
|
||||||
|
|
||||||
return batch_tokenized_inputs, image_inputs
|
video_inputs = None
|
||||||
|
return batch_tokenized_inputs, image_inputs, video_inputs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb_processor(
|
def from_pb_processor(
|
||||||
@ -160,8 +161,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "VlmCausalLMBatch":
|
) -> "VlmCausalLMBatch":
|
||||||
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
batch_tokenized_inputs, image_inputs, _video_inputs = (
|
||||||
pb.requests, tokenizer, processor, config
|
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
|
||||||
)
|
)
|
||||||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||||
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
|
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
|
||||||
@@ -68,4 +68,6 @@ class PaliGemmaBatch(VlmCausalLMBatch):
                 image_inputs = new_image_inputs
             else:
                 image_inputs = None
-        return batch_tokenized_inputs, image_inputs
+
+        video_inputs = None
+        return batch_tokenized_inputs, image_inputs, video_inputs
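Neither Mllama (above) nor PaliGemma gains video support here; both overrides just pad their return with video_inputs = None so the base class and every subclass expose one uniform three-tuple. A sketch of that contract with an illustrative stand-in:

from typing import Optional, Tuple

# Illustrative stand-in; the real method takes requests/tokenizer/processor/config.
def batch_tokenized_inputs_stub() -> Tuple[list, Optional[dict], Optional[dict]]:
    batch_tokenized_inputs = [[1, 2, 3]]
    image_inputs = {"pixel_values": "tensor-stub"}
    video_inputs = None  # unsupported: pad the tuple rather than change its arity
    return batch_tokenized_inputs, image_inputs, video_inputs

# Every call site can unpack three values, discarding what it does not use:
tokens, image_inputs, _video_inputs = batch_tokenized_inputs_stub()
assert _video_inputs is None

Keeping the arity uniform is what lets from_pb_processor stay a single code path across models.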
@@ -2,6 +2,7 @@ import torch
 from PIL import Image
 from io import BytesIO
 
+import numpy as np
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict
 
@@ -17,6 +18,7 @@ from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+import math
 
 tracer = trace.get_tracer(__name__)
 
@@ -76,6 +78,15 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
 
+def video_text_replacement(processor, video_input, config) -> str:
+    if config.model_type == "qwen2_vl":
+        num_pads = video_input.pixel_values.shape[0] // 4
+        padding = "<|video_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
 def image_text_replacement_fixup(config, text: str) -> str:
     if config.model_type == "idefics2":
         return text.replace(
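video_text_replacement must insert exactly one <|video_pad|> per vision token the encoder will produce; for qwen2_vl the processor's 2x2 spatial merge collapses four patches into one token, which is where the shape[0] // 4 above comes from. A worked sketch with an assumed patch count:

# Assume the processor flattened a clip into 6400 patches (pixel_values.shape[0]).
num_patches = 6400
merge_size = 2  # Qwen2-VL merges 2x2 patches into one vision token
num_pads = num_patches // (merge_size * merge_size)  # == num_patches // 4

prompt_fragment = "<|vision_start|>" + "<|video_pad|>" * num_pads + "<|vision_end|>"
assert prompt_fragment.count("<|video_pad|>") == 1600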
@@ -138,29 +149,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
+# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
+def smart_nframes(
+    fps: int,
+    nframes: int,
+    min_frames: int,
+    max_frames: int,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    if nframes:
+        nframes = round(nframes / 2) * 2
+    else:
+        min_frames = math.ceil(min_frames / 2) * 2
+        max_frames = math.floor(max_frames / 2) * 2
+        nframes = total_frames / video_fps * fps
+        nframes = min(max(nframes, min_frames), max_frames)
+        nframes = round(nframes / 2) * 2
+    if not (2 <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should be in interval [{2}, {total_frames}], but got {nframes}."
+        )
+    return nframes
+
+
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    video_pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    video_grid_thw: Optional[torch.Tensor]
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch
 
     @classmethod
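smart_nframes keeps the sampled frame count even (qwen2_vl consumes frames in temporal pairs) and clamps it to [min_frames, max_frames]. A worked example under assumed inputs, using a condensed copy of the helper: a 10-second clip recorded at 30 fps, sampled at 2 fps and bounded to [4, 64] frames, yields 20.

import math

def smart_nframes_demo(fps, nframes, min_frames, max_frames, total_frames, video_fps):
    # condensed copy of the helper above, for the worked example only
    if nframes:
        nframes = round(nframes / 2) * 2
    else:
        min_frames = math.ceil(min_frames / 2) * 2
        max_frames = math.floor(max_frames / 2) * 2
        nframes = total_frames / video_fps * fps              # ideal count at target fps
        nframes = min(max(nframes, min_frames), max_frames)   # clamp to bounds
        nframes = round(nframes / 2) * 2                      # force an even count
    if not (2 <= nframes <= total_frames):
        raise ValueError(f"nframes should be in interval [2, {total_frames}], but got {nframes}.")
    return nframes

print(smart_nframes_demo(fps=2, nframes=0, min_frames=4, max_frames=64,
                         total_frames=300, video_fps=30))  # -> 20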
@@ -171,6 +212,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         # can make the image splits the same size. And we need the final
         # sizes to insert correct number of image tokens.
         images = []
+        videos = []
         for r in requests:
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
@@ -190,6 +232,30 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         images.append(image)
                     else:
                         images.append([image])
+                elif chunk_type == "video":
+                    if config.model_type == "qwen2_vl":
+                        video_frame_buf = np.frombuffer(
+                            chunk.video.data, dtype=np.uint8
+                        )
+                        num_bytes = len(video_frame_buf)
+                        bytes_per_frame = num_bytes // chunk.video.frames
+
+                        # iterate over the buffer with a stride the size of one frame
+                        frames = []
+                        for i in range(chunk.video.frames):
+                            frame = video_frame_buf[
+                                i * bytes_per_frame : (i + 1) * bytes_per_frame
+                            ]
+                            frame = frame.reshape(
+                                chunk.video.height, chunk.video.width, 3
+                            )
+                            frames.append(frame)
+
+                        video_frame_buf = np.stack(frames)
+                        frame_nchw_tensor = torch.from_numpy(video_frame_buf).permute(
+                            0, 3, 1, 2
+                        )
+                        videos.append(frame_nchw_tensor)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
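The video chunk carries one flat buffer of raw RGB24 bytes plus its frame geometry, so decoding is pure reshaping: slice height * width * 3 bytes per frame, stack, and permute NHWC to NCHW. A standalone sketch on synthetic data:

import numpy as np
import torch

frames_n, height, width = 2, 4, 4  # tiny synthetic clip
raw = np.random.randint(0, 256, frames_n * height * width * 3, dtype=np.uint8).tobytes()

buf = np.frombuffer(raw, dtype=np.uint8)
bytes_per_frame = len(buf) // frames_n  # == height * width * 3

frames = [
    buf[i * bytes_per_frame : (i + 1) * bytes_per_frame].reshape(height, width, 3)
    for i in range(frames_n)
]
video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2)  # NHWC -> NCHW
assert video.shape == (frames_n, 3, height, width)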
@@ -198,6 +264,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None
 
+        video_inputs = None
+        if videos:
+            try:
+                video_inputs = processor.image_processor(
+                    videos,
+                    return_tensors="pt",
+                )
+            except Exception as e:
+                print(f"Failed to process video: {e}")
+                pass
+        else:
+            video_inputs = None
+
         batch_inputs = []
         max_truncation = 0
         image_id = 0
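The guarded call above trades robustness for silence: a clip the processor rejects degrades the request to text-only rather than failing the whole batch (a log_master warning would arguably fit this file's conventions better than print, and the trailing pass is redundant). A sketch of that fallback shape with stub processors, so it runs without downloading a checkpoint:

from typing import Optional

def preprocess_videos(videos: list, image_processor) -> Optional[dict]:
    # Mirrors the guarded call in the diff: a failing clip yields None
    # (text-only request) instead of an aborted batch.
    if not videos:
        return None
    try:
        return image_processor(videos, return_tensors="pt")
    except Exception as e:
        print(f"Failed to process video: {e}")
        return None

def ok_processor(videos, return_tensors):
    return {"pixel_values": "tensor-stub", "image_grid_thw": "grid-stub"}

def failing_processor(videos, return_tensors):
    raise ValueError("unsupported codec")

assert preprocess_videos([b"clip"], ok_processor) is not None
assert preprocess_videos([b"clip"], failing_processor) is None
assert preprocess_videos([], ok_processor) is None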
@@ -212,9 +291,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         processor, image_inputs, config, image_id
                     )
                     image_id += 1
+                elif chunk_type == "video":
+                    full_text += video_text_replacement(processor, video_inputs, config)
 
             full_text = image_text_replacement_fixup(config, full_text)
-
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
 
@@ -225,7 +305,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]
 
-        return batch_tokenized_inputs, image_inputs
+        return batch_tokenized_inputs, image_inputs, video_inputs
 
     @classmethod
     def from_pb_processor(
@@ -237,10 +317,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+        batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
             pb.requests, tokenizer, processor, config
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if video_inputs is not None:
+            if "pixel_values" in video_inputs:
+                batch.video_pixel_values = video_inputs["pixel_values"].to(
+                    device=device
+                )
+            if "image_grid_thw" in video_inputs:
+                batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
+        else:
+            batch.video_pixel_values = None
+            batch.video_grid_thw = None
+
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
             if "pixel_attention_mask" in image_inputs:
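Because the video path reuses the image processor, its outputs arrive under the image-side names (pixel_values, image_grid_thw) and are remapped onto the batch's video_* fields, with each tensor moved to the model's device. A small sketch of that defensive unpacking, with CPU stand-ins and illustrative shapes:

import torch

device = torch.device("cpu")  # stands in for the model's CUDA device

video_inputs = {
    "pixel_values": torch.randn(8, 1176),         # illustrative shapes only
    "image_grid_thw": torch.tensor([[2, 4, 4]]),  # (t, h, w) patch grid
}

video_pixel_values = None
video_grid_thw = None
if video_inputs is not None:
    if "pixel_values" in video_inputs:
        video_pixel_values = video_inputs["pixel_values"].to(device=device)
    if "image_grid_thw" in video_inputs:
        video_grid_thw = video_inputs["image_grid_thw"].to(device=device)

assert video_pixel_values is not None and video_grid_thw.shape == (1, 3)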
@@ -257,6 +350,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
             else:
                 batch.image_grid_thw = None
+            if "video_grid_thw" in image_inputs:
+                batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
         else:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
@@ -367,7 +464,7 @@ class VlmCausalLM(FlashCausalLM):
         if self.model.config.model_type == "qwen2_vl":
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids, batch.image_grid_thw
+                    input_ids, batch.image_grid_thw, batch.video_grid_thw
                 )
                 batch.position_ids = position_ids
 
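get_position_ids now also receives batch.video_grid_thw: qwen2_vl's M-RoPE assigns every vision token a (temporal, height, width) coordinate derived from the grid, and a clip contributes t * h * w / merge_size^2 such tokens. A back-of-envelope check under an assumed grid:

# Assumed grid for one clip: 4 temporal steps over a 16x16 patch grid.
t, h, w = 4, 16, 16
spatial_merge_size = 2  # Qwen2-VL collapses 2x2 patches into one token

tokens_per_clip = (t * h * w) // (spatial_merge_size**2)
print(tokens_per_clip)  # -> 256 position triples for M-RoPE to emit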
@@ -420,20 +517,26 @@ class VlmCausalLM(FlashCausalLM):
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             pixel_values=batch.pixel_values,
+            video_pixel_values=batch.video_pixel_values,
             pixel_attention_mask=batch.pixel_attention_mask,
             image_sizes=batch.image_sizes,
             image_grid_thw=batch.image_grid_thw,
+            video_grid_thw=batch.video_grid_thw,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
+        if batch.video_pixel_values is not None:
+            batch.video_pixel_values = None
         if batch.pixel_attention_mask is not None:
             batch.pixel_attention_mask = None
         if batch.image_sizes is not None:
             batch.image_sizes = None
         if batch.image_grid_thw is not None:
             batch.image_grid_thw = None
+        if batch.video_grid_thw is not None:
+            batch.video_grid_thw = None
         return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
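The None-assignments after the forward call encode a lifecycle assumption: vision tensors are consumed exactly once, during prefill, and subsequent decode steps run on the KV cache plus text embeddings alone, so the batch drops them to free memory. A minimal sketch of that cleanup with stub attributes:

class BatchStub:
    # stand-in for the real batch; only the fields the cleanup touches
    def __init__(self):
        self.pixel_values = "image-tensors"
        self.video_pixel_values = "video-tensors"
        self.image_grid_thw = "image-grid"
        self.video_grid_thw = "video-grid"

def drop_vision_inputs_after_prefill(batch):
    # mirrors the None-assignments in the diff above
    for attr in ("pixel_values", "video_pixel_values",
                 "image_grid_thw", "video_grid_thw"):
        if getattr(batch, attr) is not None:
            setattr(batch, attr, None)

b = BatchStub()
drop_vision_inputs_after_prefill(b)
assert b.video_pixel_values is None and b.video_grid_thw is None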