Add cache position

Mohit Sharma 2025-04-03 10:26:48 +00:00
parent 22aaf497b7
commit 7a57b01002
3 changed files with 17 additions and 4 deletions

@@ -193,6 +193,10 @@ RUN cd server && \
pwd && \
text-generation-server --help
RUN uv pip install torchvision --no-deps
+COPY transformers-4.51.0.dev0-py3-none-any.whl .
+RUN uv pip install transformers-4.51.0.dev0-py3-none-any.whl --no-deps
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
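The transformers dev wheel is layered into the image with --no-deps, so it replaces only the transformers package itself without re-resolving the dependencies installed earlier in the build. A minimal check (not part of this commit) that the pinned dev build is the one actually imported inside the container; the expected version prefix is taken from the wheel filename above:

```python
# Sketch: confirm the transformers dev wheel baked into the image is the one
# Python resolves at import time.
import transformers

assert transformers.__version__.startswith("4.51.0.dev0"), transformers.__version__
print("transformers", transformers.__version__, "from", transformers.__file__)
```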

@@ -209,6 +209,7 @@ try:
from text_generation_server.models.transformers_flash_vlm import (
TransformersFlashVlmCausalLM,
TransformersGemma3VlmCausalLM,
+TransformersLlama4VlmCausalLM,
)
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
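The new class is exported from the same module and imported under the same try/except guard as the other flash VLM backends, so environments where the flash transformers backend cannot be imported still start up and only log a warning. A quick sanity sketch (assuming the backend imports cleanly in your environment) showing that the Llama 4 class is a specialization of the generic flash VLM model:

```python
# Sketch: the Llama 4 backend is a thin subclass of the generic flash VLM model.
from text_generation_server.models.transformers_flash_vlm import (
    TransformersFlashVlmCausalLM,
    TransformersLlama4VlmCausalLM,
)

assert issubclass(TransformersLlama4VlmCausalLM, TransformersFlashVlmCausalLM)
```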
@@ -1030,7 +1031,7 @@ def get_model(
if FLASH_TRANSFORMERS_BACKEND:
from transformers import Llama4ForConditionalGeneration as Llama4Model
-return TransformersFlashVlmCausalLM.fallback(
+return TransformersLlama4VlmCausalLM.fallback(
model_id,
Llama4Model,
revision,
@@ -1038,10 +1039,8 @@ def get_model(
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
# how to load from preprocessor_config.json
processor_kwargs={
"use_fast": True,
"max_patches": 15,
"size": {"height": 336, "width": 336},
},
)
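With that import in place, the Llama 4 branch of get_model hands the model to the dedicated subclass instead of the generic TransformersFlashVlmCausalLM. A condensed sketch of the resulting dispatch; load_llama4 is a hypothetical wrapper, and the real get_model call forwards more arguments than shown here:

```python
# Sketch: condensed view of the Llama 4 routing after this commit.
import torch
from transformers import Llama4ForConditionalGeneration as Llama4Model
from text_generation_server.models.transformers_flash_vlm import (
    TransformersLlama4VlmCausalLM,
)


def load_llama4(model_id, revision=None, speculator=None, trust_remote_code=False):
    return TransformersLlama4VlmCausalLM.fallback(
        model_id,
        Llama4Model,
        revision,
        speculator=speculator,
        dtype=torch.bfloat16,  # the flash path pins bfloat16
        trust_remote_code=trust_remote_code,
        # additional options from the original call are elided here
    )
```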

@@ -202,7 +202,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
attn_implementation = {
"text_config": "tgi",
"vision_config": "eager",
"vision_config": "sdpa",
}
model = model_class.from_pretrained(
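attn_implementation is passed here as a dict keyed by sub-config, which the transformers build shipped above accepts for composite (multimodal) models: the text tower keeps TGI's custom registered "tgi" attention, while the vision tower switches from "eager" to "sdpa". A hedged sketch of the same mechanism using only stock implementations; the checkpoint name is a placeholder and "sdpa" stands in for TGI's custom kernel, which is registered separately inside TGI:

```python
# Sketch, assuming a transformers build that accepts per-sub-config attention
# implementations for composite models (as the 4.51 dev wheel above does).
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "some-org/llama4-checkpoint",  # placeholder model id
    attn_implementation={
        "text_config": "sdpa",     # TGI plugs its custom "tgi" kernel in here
        "vision_config": "sdpa",
    },
    torch_dtype=torch.bfloat16,
)
```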
@@ -395,6 +395,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
image_grid_thw=image_grid_thw,
attention_mask=inputs.get("attention_mask", None),
use_sdpa=inputs.get("use_sdpa", False),
+cache_position=inputs.get("cache_position", None)
).logits
logits = self.post_process_outputs(logits, lm_head_indices)
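The shared forward wrapper now threads cache_position through to the underlying HF model, defaulting to None via inputs.get, so model classes whose pre_process_inputs never sets the key behave exactly as before. A small sketch of that pass-through pattern (names are illustrative, not the repo's):

```python
# Sketch: optional keys are forwarded with safe defaults, so only models that
# explicitly populate them (like the Llama 4 subclass below) are affected.
def optional_forward_kwargs(inputs: dict) -> dict:
    return {
        "attention_mask": inputs.get("attention_mask", None),
        "use_sdpa": inputs.get("use_sdpa", False),
        "cache_position": inputs.get("cache_position", None),
    }


print(optional_forward_kwargs({}))
# {'attention_mask': None, 'use_sdpa': False, 'cache_position': None}
```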
@@ -555,3 +556,12 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
inputs["use_sdpa"] = True
return inputs
class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
inputs = super().pre_process_inputs(
input_ids, position_ids, cu_seqlen_prefill
)
inputs["cache_position"] = position_ids
return inputs
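The new subclass only overrides pre_process_inputs: it reuses the generic preparation and additionally mirrors position_ids into cache_position, which transformers models use to index the KV cache and build attention masks. A sketch of the intuition, assuming TGI's flattened prefill layout where per-token positions restart at zero for every sequence:

```python
# Sketch: in the flattened layout, position_ids already encode each token's
# offset within its own sequence, so they can double as cache_position.
import torch

# two prefill sequences of lengths 3 and 2, flattened into one ragged batch
cu_seqlen_prefill = torch.tensor([0, 3, 5])
position_ids = torch.tensor([0, 1, 2, 0, 1])

# what the subclass above forwards to the HF model
cache_position = position_ids
print(cache_position.tolist())  # [0, 1, 2, 0, 1]
```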