diff --git a/Dockerfile b/Dockerfile
index 232e22d1..570968db 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -193,6 +193,10 @@ RUN cd server && \
     pwd && \
     text-generation-server --help

+RUN uv pip install torchvision --no-deps
+COPY transformers-4.51.0.dev0-py3-none-any.whl .
+RUN uv pip install transformers-4.51.0.dev0-py3-none-any.whl --no-deps
+
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index a816e544..84437bf3 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -209,6 +209,7 @@ try:
     from text_generation_server.models.transformers_flash_vlm import (
         TransformersFlashVlmCausalLM,
         TransformersGemma3VlmCausalLM,
+        TransformersLlama4VlmCausalLM,
     )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
@@ -1030,7 +1031,7 @@ def get_model(
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model

-            return TransformersFlashVlmCausalLM.fallback(
+            return TransformersLlama4VlmCausalLM.fallback(
                 model_id,
                 Llama4Model,
                 revision,
@@ -1038,10 +1039,8 @@ def get_model(
                 speculator=speculator,
                 dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
-                # how to load from preprocessor_config.json
                 processor_kwargs={
                     "use_fast": True,
-                    "max_patches": 15,
                     "size": {"height": 336, "width": 336},
                 },
             )
diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index b81cc61f..f9eb554c 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -202,7 +202,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):

         attn_implementation = {
             "text_config": "tgi",
-            "vision_config": "eager",
+            "vision_config": "sdpa",
         }

         model = model_class.from_pretrained(
@@ -395,6 +395,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             image_grid_thw=image_grid_thw,
             attention_mask=inputs.get("attention_mask", None),
             use_sdpa=inputs.get("use_sdpa", False),
+            cache_position=inputs.get("cache_position", None),
         ).logits

         logits = self.post_process_outputs(logits, lm_head_indices)
@@ -555,3 +556,12 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["use_sdpa"] = True

         return inputs
+
+
+class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
+    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+        inputs = super().pre_process_inputs(
+            input_ids, position_ids, cu_seqlen_prefill
+        )
+        inputs["cache_position"] = position_ids
+        return inputs
\ No newline at end of file
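
For context on the new subclass at the end of transformers_flash_vlm.py: the override only augments the dict returned by the base pre_process_inputs so that Llama4's forward receives an explicit cache_position, which forward() then reads back via inputs.get("cache_position", None). The standalone sketch below mirrors that pattern with simplified stand-ins; FlashVlmBase and Llama4VlmSketch are hypothetical mocks, not the real TransformersFlashVlmCausalLM classes, so the snippet runs without the TGI server environment.

# Standalone sketch, not TGI code: illustrates the pre_process_inputs override
# pattern added by this diff, using a hypothetical FlashVlmBase as a stand-in
# for TransformersFlashVlmCausalLM.
import torch


class FlashVlmBase:
    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        # The real base method also prepares attention metadata; only the
        # fields relevant to the override are kept here.
        return {"input_ids": input_ids, "position_ids": position_ids}


class Llama4VlmSketch(FlashVlmBase):
    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
        # Reuse the flat position_ids as cache_position, mirroring what the
        # diff does for Llama4.
        inputs["cache_position"] = position_ids
        return inputs


if __name__ == "__main__":
    ids = torch.tensor([101, 102, 103, 104])
    pos = torch.tensor([0, 1, 2, 3])
    out = Llama4VlmSketch().pre_process_inputs(ids, pos, cu_seqlen_prefill=None)
    assert torch.equal(out["cache_position"], pos)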