Add llama4 (#3145)

* initial changes
* Add support for other vlm
* cleanup comment
* Improve attn_implementation
* Add comments for support of models
* add model
* add model
* fixes and improvements
* update docker
* Add cache position
* Add tests
* remove redundant changes
* remove tr version
* Upgrade doc + fix linting.
* Fixing the CI.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

This commit is contained in:
parent 3d059f91ab
commit d9bb9bebc9
@@ -65,7 +65,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
 ENV PATH="$PATH:/root/.local/bin"
 RUN uv python install ${PYTHON_VERSION}
-RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} pip setuptools packaging
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
 ENV VIRTUAL_ENV=/usr/src/.venv/
 ENV PATH="$PATH:/usr/src/.venv/bin/"
@@ -193,6 +193,9 @@ RUN cd server && \
 pwd && \
 text-generation-server --help

+# This shouldn't be necessary.
+# RUN uv pip install torchvision --no-deps
+
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
@@ -9,6 +9,7 @@ Text Generation Inference enables serving optimized models. The following sectio
 - [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
+- [Llama4](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
 - [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
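Once a Llama 4 checkpoint is being served by TGI, the new entry can be smoke-tested from Python. A minimal sketch, assuming a locally running server at http://localhost:8080 (the URL and prompt are illustrative, not part of this commit):

from huggingface_hub import InferenceClient

# Assumed local TGI endpoint; adjust to your deployment.
client = InferenceClient("http://localhost:8080")

output = client.text_generation(
    "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
    max_new_tokens=100,
)
print(output)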
@@ -0,0 +1,613 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 100,
    "prefill": [],
    "seed": null,
    "tokens": [
      {"id": 2721, "logprob": -0.21582031, "special": false, "text": " people"},
      {"id": 21807, "logprob": -0.26953125, "special": false, "text": " died"},
      {"id": 310, "logprob": -0.95703125, "special": false, "text": " in"},
      {"id": 290, "logprob": -1.3359375, "special": false, "text": " the"},
      {"id": 220, "logprob": -1.3828125, "special": false, "text": " "},
      {"id": 7284, "logprob": -0.011291504, "special": false, "text": "191"},
      {"id": 36, "logprob": -0.011413574, "special": false, "text": "8"},
      {"id": 18938, "logprob": -0.23242188, "special": false, "text": " flu"},
      {"id": 27650, "logprob": -0.0010070801, "special": false, "text": " pandemic"},
      {"id": 26, "logprob": -0.69140625, "special": false, "text": "."},
      {"id": 114059, "logprob": -1.4375, "special": false, "text": " Estimating"},
      {"id": 290, "logprob": -0.24316406, "special": false, "text": " the"},
      {"id": 10593, "logprob": -0.37304688, "special": false, "text": " death"},
      {"id": 49973, "logprob": -0.025390625, "special": false, "text": " toll"},
      {"id": 323, "logprob": -0.27539062, "special": false, "text": " of"},
      {"id": 290, "logprob": -0.057617188, "special": false, "text": " the"},
      {"id": 220, "logprob": -0.040527344, "special": false, "text": " "},
      {"id": 7284, "logprob": -0.00050735474, "special": false, "text": "191"},
      {"id": 36, "logprob": -9.298325e-06, "special": false, "text": "8"},
      {"id": 18938, "logprob": -0.09863281, "special": false, "text": " flu"},
      {"id": 27650, "logprob": -0.0011749268, "special": false, "text": " pandemic"},
      {"id": 373, "logprob": -0.32421875, "special": false, "text": " is"},
      {"id": 8210, "logprob": -0.58203125, "special": false, "text": " difficult"},
      {"id": 2895, "logprob": -0.40429688, "special": false, "text": " because"},
      {"id": 323, "logprob": -1.2734375, "special": false, "text": " of"},
      {"id": 49119, "logprob": -0.51171875, "special": false, "text": " incomplete"},
      {"id": 13308, "logprob": -0.38085938, "special": false, "text": " records"},
      {"id": 341, "logprob": -0.55859375, "special": false, "text": " and"},
      {"id": 2895, "logprob": -0.765625, "special": false, "text": " because"},
      {"id": 323, "logprob": -1.0, "special": false, "text": " of"},
      {"id": 290, "logprob": -0.828125, "special": false, "text": " the"},
      {"id": 2304, "logprob": -1.015625, "special": false, "text": " fact"},
      {"id": 511, "logprob": -0.004638672, "special": false, "text": " that"},
      {"id": 2233, "logprob": -0.953125, "special": false, "text": " many"},
      {"id": 323, "logprob": -0.87890625, "special": false, "text": " of"},
      {"id": 290, "logprob": -0.60546875, "special": false, "text": " the"},
      {"id": 6759, "logprob": -1.6484375, "special": false, "text": " extra"},
      {"id": 40657, "logprob": -0.00022125244, "special": false, "text": " deaths"},
      {"id": 1610, "logprob": -0.67578125, "special": false, "text": " were"},
      {"id": 702, "logprob": -0.30664062, "special": false, "text": " not"},
      {"id": 48692, "logprob": -0.1953125, "special": false, "text": " attributed"},
      {"id": 328, "logprob": -0.0079956055, "special": false, "text": " to"},
      {"id": 290, "logprob": -0.515625, "special": false, "text": " the"},
      {"id": 18938, "logprob": -0.0040893555, "special": false, "text": " flu"},
      {"id": 26, "logprob": -0.083496094, "special": false, "text": "."},
      {"id": 13618, "logprob": -0.515625, "special": false, "text": " Many"},
      {"id": 22215, "logprob": -1.5703125, "special": false, "text": " experts"},
      {"id": 11081, "logprob": -0.96875, "special": false, "text": " believe"},
      {"id": 511, "logprob": -0.1171875, "special": false, "text": " that"},
      {"id": 290, "logprob": -0.25195312, "special": false, "text": " the"},
      {"id": 220, "logprob": -0.828125, "special": false, "text": " "},
      {"id": 7284, "logprob": -0.00010967255, "special": false, "text": "191"},
      {"id": 36, "logprob": -8.535385e-05, "special": false, "text": "8"},
      {"id": 18938, "logprob": -0.056152344, "special": false, "text": " flu"},
      {"id": 27650, "logprob": -0.0007095337, "special": false, "text": " pandemic"},
      {"id": 26132, "logprob": -0.18847656, "special": false, "text": " killed"},
      {"id": 1867, "logprob": -0.71484375, "special": false, "text": " between"},
      {"id": 220, "logprob": -0.0062561035, "special": false, "text": " "},
      {"id": 1175, "logprob": -0.009277344, "special": false, "text": "50"},
      {"id": 341, "logprob": -0.15332031, "special": false, "text": " and"},
      {"id": 220, "logprob": -8.34465e-07, "special": false, "text": " "},
      {"id": 1135, "logprob": -0.00065612793, "special": false, "text": "100"},
      {"id": 5534, "logprob": -1.4066696e-05, "special": false, "text": " million"},
      {"id": 2721, "logprob": -0.0008392334, "special": false, "text": " people"},
      {"id": 26, "logprob": -0.54296875, "special": false, "text": "."},
      {"id": 372, "logprob": -1.8046875, "special": false, "text": " I"},
      {"id": 140680, "logprob": -0.578125, "special": false, "text": "assistant"},
      {"id": 200006, "logprob": 0.0, "special": true, "text": "<|header_end|>"},
      {"id": 368, "logprob": 0.0, "special": false, "text": "\n\n"},
      {"id": 954, "logprob": -0.032226562, "special": false, "text": "The"},
      {"id": 220, "logprob": -4.4345856e-05, "special": false, "text": " "},
      {"id": 7284, "logprob": 0.0, "special": false, "text": "191"},
      {"id": 36, "logprob": 0.0, "special": false, "text": "8"},
      {"id": 18938, "logprob": -0.015625, "special": false, "text": " flu"},
      {"id": 27650, "logprob": 0.0, "special": false, "text": " pandemic"},
      {"id": 24, "logprob": -0.0072021484, "special": false, "text": ","},
      {"id": 1437, "logprob": -0.0001707077, "special": false, "text": " also"},
      {"id": 5711, "logprob": 0.0, "special": false, "text": " known"},
      {"id": 486, "logprob": 0.0, "special": false, "text": " as"},
      {"id": 290, "logprob": -5.9604645e-07, "special": false, "text": " the"},
      {"id": 25836, "logprob": -1.4305115e-06, "special": false, "text": " Spanish"},
      {"id": 18938, "logprob": -0.0015029907, "special": false, "text": " flu"},
      {"id": 24, "logprob": -0.0052490234, "special": false, "text": ","},
      {"id": 373, "logprob": -0.3125, "special": false, "text": " is"},
      {"id": 26078, "logprob": -0.21289062, "special": false, "text": " indeed"},
      {"id": 1085, "logprob": -0.080078125, "special": false, "text": " one"},
      {"id": 323, "logprob": 0.0, "special": false, "text": " of"},
      {"id": 290, "logprob": 0.0, "special": false, "text": " the"},
      {"id": 2167, "logprob": -0.20117188, "special": false, "text": " most"},
      {"id": 92679, "logprob": -0.12695312, "special": false, "text": " devastating"},
      {"id": 854, "logprob": -0.25976562, "special": false, "text": " public"},
      {"id": 4500, "logprob": 0.0, "special": false, "text": " health"},
      {"id": 93079, "logprob": -0.50390625, "special": false, "text": " crises"},
      {"id": 310, "logprob": 0.0, "special": false, "text": " in"},
      {"id": 6023, "logprob": -0.0015182495, "special": false, "text": " human"},
      {"id": 7068, "logprob": 0.0, "special": false, "text": " history"},
      {"id": 26, "logprob": -0.0012664795, "special": false, "text": "."},
      {"id": 114059, "logprob": -0.004119873, "special": false, "text": " Estimating"},
      {"id": 290, "logprob": -0.00033569336, "special": false, "text": " the"},
      {"id": 6318, "logprob": -0.20117188, "special": false, "text": " exact"}
    ],
    "top_tokens": null
  },
  "generated_text": " people died in the 1918 flu pandemic. Estimating the death toll of the 1918 flu pandemic is difficult because of incomplete records and because of the fact that many of the extra deaths were not attributed to the flu. Many experts believe that the 1918 flu pandemic killed between 50 and 100 million people. Iassistant\n\nThe 1918 flu pandemic, also known as the Spanish flu, is indeed one of the most devastating public health crises in human history. Estimating the exact"
}
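The snapshot above is the detailed /generate payload (per-token ids, logprobs and text). A hedged sketch of requesting the same kind of detailed output through the huggingface_hub client, assuming a running TGI endpoint; the exact token values depend on the deployment:

from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")  # assumed endpoint

out = client.text_generation(
    "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
    max_new_tokens=100,
    seed=42,
    details=True,
)
print(out.generated_text)
# Each token mirrors one entry of the snapshot: id, logprob, special flag, text.
for token in out.details.tokens[:5]:
    print(token.id, token.logprob, repr(token.text))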
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image is a blank white space with no visible objects or features. It appears to be an empty or placeholder image, devoid of any content or visual elements.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1743861910,
  "id": "",
  "model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.1-dev0-native",
  "usage": {
    "completion_tokens": 34,
    "prompt_tokens": 166,
    "total_tokens": 200
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image is a blank white space with no visible objects or features.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1743861909,
  "id": "",
  "model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.1-dev0-native",
  "usage": {
    "completion_tokens": 15,
    "prompt_tokens": 166,
    "total_tokens": 181
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image is a black background with no discernible objects or features. The image appears to be a blank or empty space, devoid of any visual elements.\n\n**Key Features:**\n\n* **Color:** The dominant color of the image is black.\n* **Objects:** There are no visible objects or shapes in the image.\n* **Background:** The background of the image is a solid black color.\n\n**Conclusion:**\nIn summary, the image is a simple and empty visual representation with a black background and no",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1743861909,
  "id": "",
  "model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.1-dev0-native",
  "usage": {
    "completion_tokens": 100,
    "prompt_tokens": 166,
    "total_tokens": 266
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image shows a brown cow standing on the beach with a white face and black and white marking on its ears. The cow has a white patch around its nose and mouth. The ocean and blue sky are in the background.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1743863057,
  "id": "",
  "model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.1-dev0-native",
  "usage": {
    "completion_tokens": 46,
    "prompt_tokens": 164,
    "total_tokens": 210
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image does not depict a dog; it shows a cow standing on a beach. Therefore, there is no breed of a dog to identify.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1743863056,
  "id": "",
  "model": "ll-re/Llama-4-Scout-17B-16E-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 168,
    "total_tokens": 198
  }
}
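The chat-completion snapshots above come from TGI's OpenAI-compatible messages API. A minimal sketch of issuing the same kind of image request with the huggingface_hub client; the endpoint URL is an assumption and the reply text will vary by deployment:

from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")  # assumed endpoint

image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = client.chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        }
    ],
    max_tokens=100,
    seed=42,
)
print(response.choices[0].message.content)
print(response.usage)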
integration-tests/models/test_transformers_llama4.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import base64
from io import BytesIO
from PIL import Image

import pytest


@pytest.fixture(scope="module")
def flash_llama4_handle(launcher):
    with launcher("ll-re/Llama-4-Scout-17B-16E-Instruct", num_shard=8) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama4(flash_llama4_handle):
    await flash_llama4_handle.health(300)
    return flash_llama4_handle.client


async def test_flash_llama4(flash_llama4, response_snapshot):
    response = await flash_llama4.generate(
        "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
        seed=42,
        max_new_tokens=100,
    )

    assert (
        response.generated_text
        == " people died in the 1918 flu pandemic. Estimating the death toll of the 1918 flu pandemic is difficult because of incomplete records and because of the fact that many of the extra deaths were not attributed to the flu. Many experts believe that the 1918 flu pandemic killed between 50 and 100 million people. Iassistant\n\nThe 1918 flu pandemic, also known as the Spanish flu, is indeed one of the most devastating public health crises in human history. Estimating the exact"
    )
    assert response.details.generated_tokens == 100
    assert response == response_snapshot


async def test_flash_llama4_image_cow_dog(flash_llama4, response_snapshot):
    image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
    response = await flash_llama4.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {
                        "type": "text",
                        "text": "What is the breed of the dog in the image?",
                    },
                ],
            },
        ],
        max_tokens=100,
    )

    assert (
        response.choices[0].message.content
        == "The image does not depict a dog; it shows a cow standing on a beach. Therefore, there is no breed of a dog to identify."
    )
    assert response.usage["completion_tokens"] == 30
    assert response == response_snapshot


async def test_flash_llama4_image_cow(flash_llama4, response_snapshot):
    image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
    response = await flash_llama4.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ],
        max_tokens=100,
    )
    assert (
        response.choices[0].message.content
        == "The image shows a brown cow standing on the beach with a white face and black and white marking on its ears. The cow has a white patch around its nose and mouth. The ocean and blue sky are in the background."
    )
    assert response.usage["completion_tokens"] == 46
    assert response == response_snapshot


# Helper function to convert a Pillow image to a base64 data URL
def image_to_data_url(img: Image.Image, fmt: str) -> str:
    buffer = BytesIO()
    img.save(buffer, format=fmt)
    img_data = buffer.getvalue()
    b64_str = base64.b64encode(img_data).decode("utf-8")
    mime_type = "image/png" if fmt.upper() == "PNG" else "image/jpeg"
    return f"data:{mime_type};base64,{b64_str}"


async def test_flash_llama4_image_base64_rgba(flash_llama4, response_snapshot):
    # Create an empty 100x100 PNG image with alpha (transparent background)
    img = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
    data_url = image_to_data_url(img, "PNG")
    response = await flash_llama4.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {
                        "type": "text",
                        "text": "What do you see in this transparent image?",
                    },
                ],
            },
        ],
        max_tokens=100,
    )
    assert response == response_snapshot


async def test_flash_llama4_image_base64_rgb_png(flash_llama4, response_snapshot):
    # Create an empty 100x100 PNG image without alpha (white background)
    img = Image.new("RGB", (100, 100), (255, 255, 255))
    data_url = image_to_data_url(img, "PNG")
    response = await flash_llama4.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": "What do you see in this plain image?"},
                ],
            },
        ],
        max_tokens=100,
    )
    assert response == response_snapshot


async def test_flash_llama4_image_base64_rgb_jpg(flash_llama4, response_snapshot):
    # Create an empty 100x100 JPEG image (white background)
    img = Image.new("RGB", (100, 100), (255, 255, 255))
    data_url = image_to_data_url(img, "JPEG")
    response = await flash_llama4.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": "What do you see in this JPEG image?"},
                ],
            },
        ],
        max_tokens=100,
    )
    assert response == response_snapshot
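For reference, the data URLs produced by image_to_data_url in the base64 tests look like this; a self-contained sketch (no running server required) that inlines the same conversion:

import base64
from io import BytesIO

from PIL import Image

# Same conversion as image_to_data_url above, inlined for a standalone check.
img = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
buffer = BytesIO()
img.save(buffer, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")
assert data_url.startswith("data:image/png;base64,")
print(data_url[:60], "...")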
@@ -1,4 +1,5 @@
 use serde::{Deserialize, Serialize};
+use std::collections::{HashMap, HashSet};

 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
@@ -103,6 +104,141 @@ impl LlavaNext {
     }
 }

+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Llama4VisionConfig {
+    image_size: usize,
+    patch_size: usize,
+    pixel_shuffle_ratio: f64,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Llama4 {
+    text_config: TextConfig,
+    vision_config: Llama4VisionConfig,
+}
+
+fn gcd(a: usize, b: usize) -> usize {
+    if b == 0 {
+        a
+    } else {
+        gcd(b, a % b)
+    }
+}
+
+fn get_factors(dividend: usize) -> HashSet<usize> {
+    let mut factors_set = HashSet::new();
+
+    for i in 1..=((dividend as f64).sqrt() as usize) {
+        if dividend % i == 0 {
+            factors_set.insert(i);
+            factors_set.insert(dividend / i);
+        }
+    }
+
+    factors_set
+}
+
+fn find_supported_resolutions(max_num_chunks: usize, height: usize) -> Vec<(usize, usize)> {
+    let patch_size = height;
+
+    let mut asp_dict: HashMap<(usize, usize), Vec<(usize, usize)>> = HashMap::new();
+
+    for chunk_size in (1..=max_num_chunks).rev() {
+        let mut _factors: Vec<_> = get_factors(chunk_size).into_iter().collect();
+        _factors.sort();
+        let _asp_ratios: Vec<(usize, usize)> =
+            _factors.iter().map(|&f| (f, chunk_size / f)).collect();
+
+        for (h, w) in _asp_ratios {
+            let divisor = gcd(h, w);
+            let key = (h / divisor, w / divisor); // reduced aspect ratio as key
+
+            asp_dict.entry(key).or_default().push((h, w));
+        }
+    }
+
+    let mut possible_resolutions = vec![];
+
+    for (_key, value) in asp_dict {
+        for (h, w) in value {
+            possible_resolutions.push((h * patch_size, w * patch_size));
+        }
+    }
+
+    possible_resolutions
+}
+
+fn get_best_fit(
+    original_height: usize,
+    original_width: usize,
+    possible_resolutions: &[(usize, usize)],
+    resize_to_max_canvas: bool,
+) -> (usize, usize) {
+    let orig_h = original_height as f32;
+    let orig_w = original_width as f32;
+
+    let mut scales = Vec::with_capacity(possible_resolutions.len());
+
+    for &(h, w) in possible_resolutions.iter() {
+        let scale_h = h as f32 / orig_h;
+        let scale_w = w as f32 / orig_w;
+        let scale = scale_h.min(scale_w);
+        scales.push(scale);
+    }
+
+    let upscaling_options: Vec<f32> = scales.iter().copied().filter(|&s| s >= 1.0).collect();
+    let selected_scale = if !upscaling_options.is_empty() {
+        if resize_to_max_canvas {
+            upscaling_options.into_iter().fold(f32::MIN, f32::max)
+        } else {
+            upscaling_options.into_iter().fold(f32::MAX, f32::min)
+        }
+    } else {
+        let downscaling_options: Vec<f32> = scales.iter().copied().filter(|&s| s < 1.0).collect();
+        downscaling_options.into_iter().fold(f32::MIN, f32::max)
+    };
+
+    let chosen_canvas: Vec<(usize, usize)> = possible_resolutions
+        .iter()
+        .zip(scales.iter())
+        .filter(|&(_, &s)| (s - selected_scale).abs() < f32::EPSILON)
+        .map(|(&(h, w), _)| (h, w))
+        .collect();
+
+    if chosen_canvas.len() > 1 {
+        chosen_canvas
+            .into_iter()
+            .min_by_key(|(h, w)| h * w)
+            .unwrap()
+    } else {
+        chosen_canvas[0]
+    }
+}
+
+impl Llama4 {
+    pub fn image_size(&self) -> usize {
+        self.vision_config.image_size
+    }
+
+    pub fn patch_size(&self) -> usize {
+        self.vision_config.patch_size
+    }
+
+    pub fn pixel_shuffle_ratio(&self) -> f64 {
+        self.vision_config.pixel_shuffle_ratio
+    }
+    pub fn get_aspect_ratios(&self, height: usize, width: usize) -> (usize, usize) {
+        let patch_size = self.vision_config.image_size;
+        // How to avoid hardcoding this?
+        let max_chunks = 15;
+        let supported = find_supported_resolutions(max_chunks, patch_size);
+        let (target_h, target_w) = get_best_fit(height, width, &supported, false);
+        (target_h / patch_size, target_w / patch_size)
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
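For readers who prefer Python, the tiling logic above (enumerate supported canvases from the factorizations of each chunk count, pick the best-fit canvas, then express it in whole tiles) can be sketched as follows. This is an illustrative re-implementation, not code from the commit; the tile size of 336 and max_chunks of 15 mirror the values used elsewhere in this diff.

import math


def find_supported_resolutions(max_num_chunks: int, tile_size: int):
    # Every (h, w) grid with h * w == chunk_size, scaled by the tile size.
    resolutions = []
    for chunk_size in range(max_num_chunks, 0, -1):
        for f in range(1, math.isqrt(chunk_size) + 1):
            if chunk_size % f == 0:
                for h, w in {(f, chunk_size // f), (chunk_size // f, f)}:
                    resolutions.append((h * tile_size, w * tile_size))
    return resolutions


def get_best_fit(height, width, resolutions, resize_to_max_canvas=False):
    scales = [min(h / height, w / width) for h, w in resolutions]
    upscales = [s for s in scales if s >= 1.0]
    if upscales:
        chosen = max(upscales) if resize_to_max_canvas else min(upscales)
    else:
        chosen = max(s for s in scales if s < 1.0)
    candidates = [r for r, s in zip(resolutions, scales) if abs(s - chosen) < 1e-6]
    # Ties are broken by the smallest canvas area, as in the Rust version.
    return min(candidates, key=lambda hw: hw[0] * hw[1])


def get_aspect_ratios(height, width, tile_size=336, max_chunks=15):
    resolutions = find_supported_resolutions(max_chunks, tile_size)
    target_h, target_w = get_best_fit(height, width, resolutions)
    return target_h // tile_size, target_w // tile_size


# Example: map an arbitrary image size onto a grid of 336x336 tiles.
print(get_aspect_ratios(1024, 768))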
@@ -258,6 +394,7 @@ pub enum Config {
     Phi3,
     Phimoe,
     Llama,
+    Llama4(Llama4),
     Baichuan,
     Paligemma(Paligemma),
     Gemma,
@@ -179,6 +179,7 @@ pub enum HubPreprocessorConfig {
     Idefics2Processor(Idefics2Preprocessor),
     Idefics3Processor(Idefics2Preprocessor),
     Gemma3Processor(Gemma3Processor),
+    Llama4Processor(Llama4Processor),
 }

 impl HubPreprocessorConfig {
@@ -200,6 +201,12 @@ pub struct Gemma3Processor {
     do_image_splitting: bool,
 }

+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Llama4Processor {
+    #[serde(default)]
+    do_image_splitting: bool,
+}
+
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,
@@ -687,6 +687,47 @@ fn image_tokens(
         }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        Llama4(config) => {
+            const IMAGE_START: &str = "<|image_start|>";
+            const IMAGE: &str = "<|image|>";
+            const IMAGE_END: &str = "<|image_end|>";
+            const PATCH: &str = "<|patch|>";
+            const TILE_X_SEP: &str = "<|tile_x_separator|>";
+            const TILE_Y_SEP: &str = "<|tile_y_separator|>";
+
+            let image_height = config.image_size();
+            let patch_size = config.patch_size();
+            let pixel_shuffle_ratio = config.pixel_shuffle_ratio();
+            let downsample_ratio =
+                (1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize;
+
+            let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width);
+            let image_width = image_height; // Assuming pixel shape: [H][W][C]
+
+            let num_patches_per_chunk =
+                (image_height / patch_size) * (image_width / patch_size) / downsample_ratio;
+
+            let mut img_string = String::new();
+            img_string.push_str(IMAGE_START);
+
+            if ratio_h * ratio_w > 1 {
+                for _yy in 0..ratio_h {
+                    for xx in 0..ratio_w {
+                        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
+                        if xx < ratio_w - 1 {
+                            img_string.push_str(TILE_X_SEP);
+                        }
+                    }
+                    img_string.push_str(TILE_Y_SEP);
+                }
+            }
+
+            img_string.push_str(IMAGE);
+            img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
+            img_string.push_str(IMAGE_END);
+
+            img_string
+        }
         Qwen2Vl(config) => format!(
             "<|vision_start|>{:?}<|vision_end|>",
             "<|image_pad|>".repeat(config.get_number_of_features(height, width))
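To make the layout of the placeholder string concrete, here is an illustrative Python sketch of the same scheme, using the special-token strings from the Rust constants above. The patch count in the example assumes patch_size=14 and pixel_shuffle_ratio=0.5, which are not taken from this diff and are only for illustration:

def llama4_image_string(ratio_h: int, ratio_w: int, num_patches_per_chunk: int) -> str:
    # Mirrors the Rust image_tokens branch above: per-tile patches first (row by
    # row with x/y separators), then one global view of the image.
    parts = ["<|image_start|>"]
    if ratio_h * ratio_w > 1:
        for _ in range(ratio_h):
            for xx in range(ratio_w):
                parts.append("<|patch|>" * num_patches_per_chunk)
                if xx < ratio_w - 1:
                    parts.append("<|tile_x_separator|>")
            parts.append("<|tile_y_separator|>")
    parts.append("<|image|>")
    parts.append("<|patch|>" * num_patches_per_chunk)
    parts.append("<|image_end|>")
    return "".join(parts)


# Assumed values for illustration: a 336x336 tile with patch_size=14 and a
# pixel shuffle ratio of 0.5 gives (336 // 14) ** 2 // 4 = 144 patches per tile.
print(llama4_image_string(2, 2, 144).count("<|patch|>"))  # 4 tiles + 1 global view = 720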
@@ -730,8 +771,8 @@ fn prepare_input<T: TokenizerTrait>(
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
         Some(
-            config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Paligemma(_)
-            | LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
+            config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Llama4(_)
+            | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
         ) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
@@ -32,7 +32,8 @@ dependencies = [
     "tokenizers>=0.20.3",
     "typer>=0.15.1",
     "transformers>=4.49.0",
-    "huggingface-hub>=0.29.0",
+    "huggingface-hub>=0.30.1",
+    "hf-xet>=1.0.0",
 ]

 [build-system]
@@ -206,7 +206,13 @@ try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
     )
-except ImportError:
+    from text_generation_server.models.transformers_flash_vlm import (
+        TransformersFlashVlmCausalLM,
+        TransformersGemma3VlmCausalLM,
+        TransformersLlama4VlmCausalLM,
+    )
+except ImportError as e:
+    log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
     FLASH_TRANSFORMERS_BACKEND = False

@@ -244,6 +250,11 @@ class ModelType(enum.Enum):
         "name": "Llama",
         "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
     }
+    LLAMA4 = {
+        "type": "llama4",
+        "name": "Llama4",
+        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
+    }
     PHI3 = {
         "type": "phi3",
         "name": "Phi 3",
@@ -648,7 +659,6 @@ def get_model(
         raise ValueError(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )
-
     if model_type == DEEPSEEK_V2:
         if FLASH_ATTENTION:
             head_size = max(
@@ -1017,7 +1027,23 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+    elif model_type == LLAMA4:
+        if FLASH_TRANSFORMERS_BACKEND:
+            from transformers import Llama4ForConditionalGeneration as Llama4Model
+
+            return TransformersLlama4VlmCausalLM.fallback(
+                model_id,
+                Llama4Model,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                processor_kwargs={
+                    "use_fast": True,
+                    "size": {"height": 336, "width": 336},
+                },
+            )
    elif model_type == BAICHUAN:
        if FLASH_ATTENTION:
            return FlashCausalLM(
@@ -1155,7 +1181,6 @@ def get_model(
             )
     elif model_type == GEMMA3:
         if FLASH_ATTENTION:
-            # TODO: Use VlmCausalLM when image support is added.
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Gemma3ForConditionalGeneration,
@@ -1173,12 +1198,15 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         elif FLASH_TRANSFORMERS_BACKEND:
-            return TransformersFlashCausalLM.fallback(
+            from transformers import Gemma3ForConditionalGeneration as Gemma3Model
+
+            return TransformersGemma3VlmCausalLM.fallback(
                 model_id,
+                Gemma3Model,
                 revision,
                 quantize=quantize,
                 speculator=speculator,
-                dtype=dtype,
+                dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
             )
     elif sharded:
@@ -1483,6 +1511,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Qwen2VLForConditionalGeneration,
@@ -1495,7 +1524,23 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
+
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
     if model_type == QWEN2_5_VL:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Qwen2_5VLForConditionalGeneration,
@@ -1510,6 +1555,21 @@ def get_model(
                 config_class=Qwen2_5_VLConfig,
                 processor_class=Qwen2_5_VLProcessor,
             )
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         config_class=Qwen2_5_VLConfig,
+        #         processor_class=Qwen2_5_VLProcessor,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
             return MllamaCausalLM(
@@ -1524,6 +1584,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        # TODO: Uncomment when transformers is refactored and cross attn is added
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import MllamaForConditionalGeneration as MllamaModel
+
+        #     return TransformersFlashVlmCausalLM.fallback(
+        #         model_id,
+        #         MllamaModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         batch_class=MllamaCausalLMBatch,
+        #     )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
@@ -1542,6 +1616,19 @@ def get_model(
                 # VRAM usage.
                 processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import Idefics2ForConditionalGeneration as Idefics2Model
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                Idefics2Model,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == IDEFICS3:
@@ -1560,6 +1647,19 @@ def get_model(
                 # VRAM usage.
                 processor_kwargs={"size": {"longest_edge": 1456}},
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import Idefics3ForConditionalGeneration as Idefics3Model
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                Idefics3Model,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                processor_kwargs={"size": {"longest_edge": 1456}},
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
@@ -1578,9 +1678,21 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
                 batch_class=PaliGemmaBatch,
             )
-        else:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                PaliGemmaModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                batch_class=PaliGemmaBatch,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))
     if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return VlmCausalLM(
@@ -1593,6 +1705,18 @@ def get_model(
                 kv_cache_dtype=kv_cache_dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import LlavaNextForConditionalGeneration as LlavaNextModel
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                LlavaNextModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
@@ -1344,9 +1344,6 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

-    def max_past(self) -> int:
-        return getattr(self.model, "max_past", None)
-
     def init_kv_cache(
         self,
         num_blocks: int,
@@ -1792,12 +1789,6 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
         if sorted_padded_bs:
@@ -36,6 +36,7 @@ def tgi_flash_attention_forward(
     softcap: Optional[float] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):

     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
@@ -72,6 +73,7 @@ def tgi_flash_attention_forward(
             max_s,
             kv_scales=kv_scales,
             softcap=softcap,
+            window_size_left=sliding_window,
         )

     attn_output = attn_output.view(-1, num_heads * head_dim)
@@ -157,7 +159,14 @@ class TransformersFlashCausalLM(FlashCausalLM):
         self.num_layers = model.config.num_hidden_layers
         self.num_heads = model.config.num_attention_heads
         self.num_kv_heads = model.config.num_key_value_heads
-        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+        # Some models use GQA and different sizes for o_proj
+        # and q_proj, that allows for that.
+        if hasattr(model.config, "head_dim"):
+            self.head_size = model.config.head_dim
+        else:
+            self.head_size = (
+                model.config.hidden_size // model.config.num_attention_heads
+            )

         # Skip it for models in the exception list
         if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
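The head-size fallback above matters for checkpoints whose head_dim is decoupled from hidden_size / num_attention_heads. A small self-contained sketch of the same resolution rule; the config values are made up for illustration:

from types import SimpleNamespace


def resolve_head_size(config) -> int:
    # Same rule as above: prefer an explicit head_dim, otherwise derive it.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
explicit = SimpleNamespace(hidden_size=2048, num_attention_heads=16, head_dim=96)

print(resolve_head_size(derived))   # 4096 // 32 = 128
print(resolve_head_size(explicit))  # 96, not 2048 // 16 = 128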
|
566
server/text_generation_server/models/transformers_flash_vlm.py
Normal file
566
server/text_generation_server/models/transformers_flash_vlm.py
Normal file
@ -0,0 +1,566 @@
import math
from typing import List, Optional

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoProcessor
import transformers.modeling_utils

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
from text_generation_server.utils import initialize_torch_distributed

from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
import torch.nn.functional as F

tracer = trace.get_tracer(__name__)

# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
# to internal constraints it was not (yet?) possible to circumvent
REPLICATED_ATTENTION_MODELS = [
    "olmo2",
    "phi3",
]


# # Qwen2VL
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
#     "tgi"
# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
#     "eager"
# ]
def tgi_flash_attention_forward(
    module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
    kv_cache: List[KVCache],
    kv_head_mapping: torch.Tensor,
    slots: torch.Tensor,
    cu_seqlen_prefill: Optional[torch.Tensor],
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    max_s: int,
    kv_scales: KVScales,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    use_sdpa: Optional[bool] = False,
    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
):
    kv_cache = kv_cache[module.layer_idx]
    query_states = query_states.transpose(1, 2).squeeze(dim=0)
    key_states = key_states.transpose(1, 2).squeeze(dim=0)
    value_states = value_states.transpose(1, 2).squeeze(dim=0)

    # Take care of updating the cache in-place
    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)

    _, num_heads, head_dim = query_states.shape
    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    sliding_window = -1 if sliding_window is None else sliding_window

    if cu_seqlen_prefill is not None:
        if not use_sdpa:
            attn_output = attention(
                query=query_states,
                key=key_states,
                value=value_states,
                kv_cache=kv_cache,
                kv_scales=kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=softmax_scale,
                window_size_left=sliding_window,
                softcap=softcap,
            )
        else:
            lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
            max_length = max(lengths)
            attention_mask = attention_mask[:, :, :, :max_length]
            enable_gqa = query_states.shape[1] != key_states.shape[1]
            # Split tensors using vectorized split
            query_list = torch.split(query_states, lengths.tolist(), dim=0)
            key_list = torch.split(key_states, lengths.tolist(), dim=0)
            value_list = torch.split(value_states, lengths.tolist(), dim=0)

            padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True)
            padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
            padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True)

            padded_query = padded_query.transpose(1, 2).contiguous()
            padded_key = padded_key.transpose(1, 2).contiguous()
            padded_value = padded_value.transpose(1, 2).contiguous()

            # Compute attention
            attn_output = F.scaled_dot_product_attention(
                padded_query,
                padded_key,
                padded_value,
                attn_mask=attention_mask,
                scale=softmax_scale,
                enable_gqa=enable_gqa,
            )

            attn_output = attn_output.transpose(
                1, 2
            )  # [batch_size, seq_len, num_heads, head_dim]
            max_seq_len = padded_query.size(2)
            seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze(
                0
            )
            lengths_tensor = torch.tensor(
                lengths, device=padded_query.device
            ).unsqueeze(1)
            mask = seq_range < lengths_tensor  # [batch, max_seq_len]
            attn_output = attn_output[mask]  # [total_seq_len, num_heads, head_dim]

    else:
        attn_output = paged_attention(
            query_states,
            kv_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            seqlen,
            max_s,
            kv_scales=kv_scales,
            softcap=softcap,
            window_size_left=sliding_window,
        )

    attn_output = attn_output.view(-1, num_heads * head_dim)

    return attn_output, None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward


# TODO: implement
# tgi_cross_attention_forward


class TransformersFlashVlmCausalLM(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        kv_cache_dtype: Optional[torch.dtype] = None,
        batch_class=VlmCausalLMBatch,
    ):
        self.batch_class = batch_class
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        self.dtype = dtype

        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            dtype = default_dtype if dtype is None else dtype
        else:
            raise ValueError(
                "Flash `Transformers` modeling backend is not available on cpu."
            )

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        if processor_kwargs is None:
            processor_kwargs = {}

        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )

        attn_implementation = {
            "text_config": "tgi",
            "vision_config": "sdpa",
        }

        model = model_class.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            device_map=device if world_size == 1 else None,
            tp_plan="auto" if world_size > 1 else None,
        )

        torch.distributed.barrier(group=self.process_group)
        self.config = model.config
        config = model.config

        # VLM models define the config we care about in their text_config
        text_config = getattr(model.config, "text_config", None)
        if text_config is not None:
            config = text_config

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None and isinstance(
                model.config.eos_token_id, int
            ):
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Some models use GQA and different sizes for o_proj
        # and q_proj, that allows for that.
        if hasattr(config, "head_dim"):
            self.head_size = config.head_dim
        else:
            self.head_size = config.hidden_size // config.num_attention_heads

        # Skip it for models in the exception list
        if config.model_type not in REPLICATED_ATTENTION_MODELS:
            self.num_heads = self.num_heads // self.process_group.size()
            self.num_kv_heads = (
                self.num_kv_heads // self.process_group.size()
                if self.num_kv_heads > 1
                else self.num_kv_heads
            )

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        self.num_groups = self.num_heads // self.num_kv_heads

        # Those will never change and will be used in the forwards
        self.kv_head_mapping = torch.arange(
            0, self.num_kv_heads, dtype=torch.int32, device=device
        ).repeat_interleave(self.num_groups)
        # This means no scale
        self.kv_scales = KVScales(
            torch.tensor(1.0, device=device),
            torch.tensor(1.0, device=device),
        )

        # Skip FlashCausalLM init.
        super(FlashCausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
        # We first copy the original model.forward because we still need it in the monkey patch
        self.model.original_forward = self.model.forward
        self.model.forward = self._model_forward
        self.model.get_position_ids = self.get_position_ids

        torch.distributed.barrier(group=self.process_group)

    def get_position_ids(self, input_ids, image_grid_thw, position_ids):
        return position_ids

    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        return {
            "input_ids": input_ids.unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }

    def post_process_outputs(self, logits, lm_head_indices):
        return logits.squeeze(dim=0)

    @classmethod
    def fallback(
        cls,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        batch_class: Optional[type] = VlmCausalLMBatch,
        processor_kwargs: Optional[dict] = None,
    ):
        return cls(
            model_id=model_id,
            model_class=model_class,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=batch_class,
            processor_kwargs=processor_kwargs,
        )

    def _model_forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[KVCache],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor],
        prefill_cache_indices=None,  # not used, but passed to match original signature
        adapter_data=None,  # not supported, but passed to match original signature
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
    ):
        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

        inputs = self.pre_process_inputs(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
        )
        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
        logits = self.model.original_forward(
            input_ids=inputs["input_ids"],
            position_ids=inputs["position_ids"],
            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
            use_cache=False,  # we use self.kv_cache instead of transformers cache object
            logits_to_keep=logits_to_keep,
            return_dict=True,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            kv_head_mapping=self.kv_head_mapping,
            kv_scales=self.kv_scales,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_sizes=image_sizes,
            image_grid_thw=image_grid_thw,
            attention_mask=inputs.get("attention_mask", None),
            use_sdpa=inputs.get("use_sdpa", False),
            cache_position=inputs.get("cache_position", None),
        ).logits

        logits = self.post_process_outputs(logits, lm_head_indices)

        return logits, None


class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
    def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor):
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        vision_start_token_id = self.config.vision_start_token_id
        vision_end_token_id = self.config.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # import ipdb

        # ipdb.set_trace()
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def post_process_outputs(self, logits, lm_head_indices):
        return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)

    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        input_ids = input_ids.unsqueeze(0)
        position_ids = position_ids.transpose(0, 1).unsqueeze(1)
        return {"input_ids": input_ids, "position_ids": position_ids}


class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
    def get_attention_mask(self, input_ids, cu_seqlen_prefill):
        device = input_ids.device
        dtype = self.dtype
        min_dtype = torch.finfo(dtype).min

        lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
        batch_size = len(lengths)

        sequence_length = max(lengths)
        target_length = sequence_length
        # Create the padding mask from the computed lengths.
        # pad_mask: [batch, sequence_length] where True indicates valid tokens.
        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
        lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
        pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]

        # Build the base causal mask (for non-image tokens):
        causal_mask = torch.tril(
            torch.ones(
                (sequence_length, sequence_length), dtype=torch.bool, device=device
            )
        )
        base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
            1
        )  # [batch, sequence_length, sequence_length]
        base_mask = base_mask & causal_mask.unsqueeze(0)  # apply causal constraint

        image_token_mask = (input_ids == self.config.image_token_index).to(
            input_ids.device
        )

        image_token_mask = torch.nn.utils.rnn.pad_sequence(
            torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
        )
        bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
            1
        )

        # Combine the causal base mask and the bidirectional mask.
        combined_mask = torch.logical_or(
            base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
        ).to(device)
        # combined_mask now has shape [batch, 1, sequence_length, sequence_length]

        full_attention_mask = torch.zeros(
            (batch_size, 1, sequence_length, target_length),
            device=device,
            dtype=torch.bool,
        )
        full_attention_mask[:, :, :, :sequence_length] = combined_mask

        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)

        return final_attention_mask

    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        inputs = {
            "input_ids": input_ids.unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }

        if cu_seqlen_prefill is not None:
            attention_mask = self.get_attention_mask(
                input_ids.squeeze(0), cu_seqlen_prefill
            )
            inputs["attention_mask"] = attention_mask
            inputs["use_sdpa"] = True

        return inputs


class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
        inputs["cache_position"] = position_ids
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
        return inputs
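The new backend above hooks into Transformers in two places: it registers a custom attention callable under the "tgi" key in `transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS`, and it then selects that implementation per sub-model with `attn_implementation={"text_config": "tgi", "vision_config": "sdpa"}`. A minimal sketch of the same registration pattern, assuming a recent Transformers release that exposes `ALL_ATTENTION_FUNCTIONS`; the logging wrapper and the commented-out checkpoint id are illustrative assumptions, not part of this commit:

import transformers.modeling_utils as modeling_utils

def logging_sdpa_forward(module, query, key, value, attention_mask, **kwargs):
    # Same (module, q, k, v, mask, **kwargs) -> (output, weights) contract as
    # tgi_flash_attention_forward above; this wrapper just prints the query
    # shape and delegates to the stock "sdpa" implementation.
    print("attention called with query shape", tuple(query.shape))
    sdpa = modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"]
    return sdpa(module, query, key, value, attention_mask, **kwargs)

modeling_utils.ALL_ATTENTION_FUNCTIONS["logging_sdpa"] = logging_sdpa_forward

# The registered key can then be chosen per sub-config when loading a VLM, e.g.:
# model = AutoModelForImageTextToText.from_pretrained(
#     "some/checkpoint",  # hypothetical model id
#     attn_implementation={"text_config": "logging_sdpa", "vision_config": "sdpa"},
# )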
@ -29,6 +29,33 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"


def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
    """
    Create a structured string representation of image tokens

    Args:
        aspect_ratio: Tuple (ratio_h, ratio_w) describing the tile grid of the image
        num_patches_per_chunk: Number of patches in each image chunk

    Returns:
        String with appropriate image tokens
    """
    img_string = "<|image_start|>"
    ratio_h, ratio_w = aspect_ratio
    if ratio_h * ratio_w > 1:
        for yy in range(ratio_h):
            for xx in range(ratio_w):
                img_string += "<|patch|>" * num_patches_per_chunk
                if xx < ratio_w - 1:
                    img_string += "<|tile_x_separator|>"

            img_string += "<|tile_y_separator|>"
    img_string += "<|image|>"
    img_string += "<|patch|>" * num_patches_per_chunk
    img_string += "<|image_end|>"

    return img_string


# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
    *,
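To make the token layout produced by `prompt_split_image_llama4` (added in the hunk above) concrete, here is a small usage sketch; the aspect ratios and patch count are made-up values for illustration:

from text_generation_server.models.vlm_causal_lm import prompt_split_image_llama4

# 1x1 aspect ratio (no tiling): only the global <|image|> chunk is emitted.
print(prompt_split_image_llama4((1, 1), 2))
# -> <|image_start|><|image|><|patch|><|patch|><|image_end|>

# 2x2 tiling: four tiles separated by <|tile_x_separator|> / <|tile_y_separator|>,
# then the global <|image|> chunk (a single string, wrapped here for readability):
print(prompt_split_image_llama4((2, 2), 2))
# -> <|image_start|><|patch|><|patch|><|tile_x_separator|><|patch|><|patch|><|tile_y_separator|>
#    <|patch|><|patch|><|tile_x_separator|><|patch|><|patch|><|tile_y_separator|>
#    <|image|><|patch|><|patch|><|image_end|>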
@ -134,6 +161,23 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
        num_pads = 256
        padding = "<image_soft_token>" * num_pads
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
    elif config.model_type == "llama4":
        patch_size = config.vision_config.patch_size
        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
        aspect_ratios = image_input["aspect_ratios"][image_id]
        image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]

        num_patches_per_chunk = int(
            (image_height // patch_size)
            * (image_width // patch_size)
            // downsample_ratio
        )
        tokens_for_this_image = prompt_split_image_llama4(
            aspect_ratios, num_patches_per_chunk
        )

        return tokens_for_this_image
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")

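As a sanity check on the `num_patches_per_chunk` arithmetic above, here is the same computation with illustrative numbers (the patch size, pixel-shuffle ratio, and chunk resolution are assumptions for the example, not values read from an actual Llama4 config):

# Assumed, illustrative values.
patch_size = 14
pixel_shuffle_ratio = 0.5
image_height = image_width = 336  # resolution of one pre-processed image chunk

downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))  # 1 / 0.25 -> 4
num_patches_per_chunk = int(
    (image_height // patch_size) * (image_width // patch_size) // downsample_ratio
)
print(num_patches_per_chunk)
# (336 // 14) * (336 // 14) // 4 == 24 * 24 // 4 == 144 <|patch|> tokens per chunk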
@ -252,6 +296,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                    images.append(image)
                elif config.model_type == "gemma3":
                    images.append(image)
                elif config.model_type == "llama4":
                    images.append(image)
                else:
                    images.append([image])
            else:
@ -285,7 +331,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                    processor, image_inputs, config, image_id
                )
                image_id += 1
        # from pdb import set_trace; set_trace()
        full_text = image_text_replacement_fixup(config, full_text)
        input_ids = tokenizer(
            full_text,
@ -372,9 +418,6 @@ class VlmCausalLM(FlashCausalLM):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
        batch: VlmCausalLMBatch,
@ -442,12 +485,6 @@ class VlmCausalLM(FlashCausalLM):
            )
            batch.position_ids = position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@ -707,9 +707,24 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/a1/14/f1e15b851d1c2af5b0b1a82bf8eb10bda2da62d98180220ba6fd8879bb5b/hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad", size = 1160240 },
]

[[package]]
name = "hf-xet"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/64/46/db229dddc55121478105940b610fef1b466c414da02be9d4daa5602a2527/hf_xet-1.0.0.tar.gz", hash = "sha256:5e0ca891ce599fd753e7ffbdc182207d952a93e6252eeb92118475d6866bb093", size = 257192 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/e7/0a/c16f8766fa3cd520292b1a765e9b50b8390bce4c2ed7657db9534551f5ed/hf_xet-1.0.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6106304f92bbce7c9b8509f6f735f2e8ce95e4dc32af8876e874c48b15ca1903", size = 5001841 },
    { url = "https://files.pythonhosted.org/packages/e3/9f/cca55edd85d03fc98c743bcc093965740a7440e909779c558039d6838f03/hf_xet-1.0.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:4d0bc7a3e6c1d21fcbb48e8726e3b19a2460e95971375e55e9a5f73ec7079a86", size = 4805318 },
    { url = "https://files.pythonhosted.org/packages/d1/0b/28bda7ac9d699dcfb96f628aa135ddca3f0f77e9716351aab2b83966f957/hf_xet-1.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23dee64f114ea9a272ff71a6a755e025d7a075a6a0dbf6da0990fe9211831ccf", size = 53504907 },
    { url = "https://files.pythonhosted.org/packages/cb/04/ef1f7249a813841d193cbab2ef4d1d7d67c66c61d21d45223a72fdc5c88e/hf_xet-1.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d5f160550508ed87783d1eca4288602a713d1c45ec517d937acb9d93120f0cab", size = 52410434 },
    { url = "https://files.pythonhosted.org/packages/81/b3/e7abec2619ecd9d1c743adfe79fa69cf84530f530969daf3dc804efef65b/hf_xet-1.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5ebd79db87df0b9d3607e7c9a6bb0662c10e36992f522f66b1d2a7fe93f53f27", size = 53465113 },
    { url = "https://files.pythonhosted.org/packages/df/82/b51f3b6e5c6f33e91220c37b17760229704c58e79ab0fcfd0fd3b55803d3/hf_xet-1.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8e6d2625971b4affad634835db82d5392f38de874205a9573e0dd3f0f9cb136f", size = 53461632 },
    { url = "https://files.pythonhosted.org/packages/95/d2/32defba26d995f7acdc4fe3e5911473b25aff5b75c5a2532786435a709e8/hf_xet-1.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:b446964bd75eb7f6b4d983c47241b2023eadfad1f56137ed00e1ca6fc278faec", size = 4121808 },
]

[[package]]
name = "huggingface-hub"
version = "0.29.1"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "filelock" },
@ -720,9 +735,9 @@ dependencies = [
    { name = "tqdm" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776 }
sdist = { url = "https://files.pythonhosted.org/packages/78/be/049689a7197630e75c4bb53021cb209a56617c9bf39b3a0950650d1f96e1/huggingface_hub-0.30.1.tar.gz", hash = "sha256:f379e8b8d0791295602538856638460ae3cf679c7f304201eb80fb98c771950e", size = 400784 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049 },
    { url = "https://files.pythonhosted.org/packages/99/e3/2232d0e726d4d6ea69643b9593d97d0e7e6ea69c2fe9ed5de34d476c1c47/huggingface_hub-0.30.1-py3-none-any.whl", hash = "sha256:0f6aa5ec5a4e68e5b9e45d556b4e5ea180c58f5a5ffa734e7f38c9d573028959", size = 481170 },
]

[[package]]
@ -2563,6 +2578,7 @@ dependencies = [
    { name = "grpcio-reflection" },
    { name = "grpcio-status" },
    { name = "hf-transfer" },
    { name = "hf-xet" },
    { name = "huggingface-hub" },
    { name = "kernels" },
    { name = "loguru" },
@ -2628,7 +2644,8 @@ requires-dist = [
    { name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.51.1,<2.0" },
    { name = "grpcio-tools", marker = "extra == 'gen'", specifier = ">=1.69.0" },
    { name = "hf-transfer", specifier = ">=0.1.8" },
    { name = "huggingface-hub", specifier = ">=0.29.0" },
    { name = "hf-xet", specifier = ">=1.0.0" },
    { name = "huggingface-hub", specifier = ">=0.30.1" },
    { name = "kernels", specifier = ">=0.2.1" },
    { name = "loguru", specifier = ">=0.7.3" },
    { name = "mypy-protobuf", marker = "extra == 'gen'", specifier = ">=3.6.0" },