mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 13:02:12 +00:00

Add cache position

parent 22aaf497b7
commit 7a57b01002
@@ -193,6 +193,10 @@ RUN cd server && \
     pwd && \
     text-generation-server --help
 
+RUN uv pip install torchvision --no-deps
+COPY transformers-4.51.0.dev0-py3-none-any.whl .
+RUN uv pip install transformers-4.51.0.dev0-py3-none-any.whl --no-deps
+
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
@@ -209,6 +209,7 @@ try:
     from text_generation_server.models.transformers_flash_vlm import (
         TransformersFlashVlmCausalLM,
         TransformersGemma3VlmCausalLM,
+        TransformersLlama4VlmCausalLM,
     )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
@@ -1030,7 +1031,7 @@ def get_model(
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model
 
-            return TransformersFlashVlmCausalLM.fallback(
+            return TransformersLlama4VlmCausalLM.fallback(
                 model_id,
                 Llama4Model,
                 revision,
@@ -1038,10 +1039,8 @@ def get_model(
                 speculator=speculator,
                 dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
-                # how to load from preprocessor_config.json
                 processor_kwargs={
                     "use_fast": True,
-                    "max_patches": 15,
                     "size": {"height": 336, "width": 336},
                 },
             )
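For orientation, the processor_kwargs above presumably end up in the Hugging Face processor that fallback() constructs, overriding values that would otherwise come from the checkpoint's preprocessor_config.json. A minimal sketch of that assumption follows; the model id is a placeholder and the forwarding behaviour is inferred rather than taken from this diff.

from transformers import AutoProcessor

# Assumption: fallback() passes processor_kwargs through to the processor
# loader, so these values override the shipped preprocessor_config.json.
processor = AutoProcessor.from_pretrained(
    "org/llama4-checkpoint-placeholder",  # placeholder model id
    use_fast=True,
    size={"height": 336, "width": 336},
)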
@@ -202,7 +202,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
 
         attn_implementation = {
             "text_config": "tgi",
-            "vision_config": "eager",
+            "vision_config": "sdpa",
         }
 
         model = model_class.from_pretrained(
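The hunk above switches the vision tower from eager attention to SDPA while keeping the custom "tgi" backend for the text tower. A minimal sketch of the pattern, assuming a recent transformers release that accepts a per-sub-config dict for attn_implementation; the auto class and model id are placeholders, and the "tgi" implementation is assumed to be registered with transformers by TGI elsewhere.

import torch
from transformers import AutoModelForImageTextToText

# Per-submodule attention backends: the text tower uses the custom "tgi"
# kernel (assumed registered elsewhere), the vision tower uses PyTorch SDPA.
attn_implementation = {
    "text_config": "tgi",
    "vision_config": "sdpa",
}

model = AutoModelForImageTextToText.from_pretrained(
    "org/vlm-checkpoint-placeholder",  # placeholder model id
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)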
@@ -395,6 +395,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             image_grid_thw=image_grid_thw,
             attention_mask=inputs.get("attention_mask", None),
             use_sdpa=inputs.get("use_sdpa", False),
+            cache_position=inputs.get("cache_position", None)
         ).logits
 
         logits = self.post_process_outputs(logits, lm_head_indices)
@@ -555,3 +556,12 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["use_sdpa"] = True
 
         return inputs
+
+
+class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
+    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+        inputs = super().pre_process_inputs(
+            input_ids, position_ids, cu_seqlen_prefill
+        )
+        inputs["cache_position"] = position_ids
+        return inputs
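Taken together, the change threads an explicit cache_position tensor from input preparation into the underlying transformers model call. A condensed, hypothetical sketch of that flow; the class and method bodies below are simplified stand-ins for illustration, not TGI's actual interfaces.

class FlashVlmSketch:
    """Simplified stand-in for TransformersFlashVlmCausalLM (illustration only)."""

    def __init__(self, hf_model):
        self.model = hf_model

    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        # The real method builds many more fields; only the relevant ones are shown.
        return {"input_ids": input_ids, "position_ids": position_ids}

    def forward(self, input_ids, position_ids, cu_seqlen_prefill):
        inputs = self.pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
        # cache_position is optional: only subclasses that set it (the new
        # Llama 4 path) pass a tensor, every other model forwards None.
        return self.model(
            input_ids=inputs["input_ids"],
            position_ids=inputs["position_ids"],
            cache_position=inputs.get("cache_position", None),
        ).logits


class Llama4Sketch(FlashVlmSketch):
    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
        # Reuse the flat position ids as cache_position, mirroring the hunk above.
        inputs["cache_position"] = position_ids
        return inputs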
|
Loading…
Reference in New Issue
Block a user