mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 13:02:12 +00:00

Add cache position

This commit is contained in:
  parent 22aaf497b7
  commit 7a57b01002
@@ -193,6 +193,10 @@ RUN cd server && \
     pwd && \
     text-generation-server --help
 
+RUN uv pip install torchvision --no-deps
+COPY transformers-4.51.0.dev0-py3-none-any.whl .
+RUN uv pip install transformers-4.51.0.dev0-py3-none-any.whl --no-deps
+
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
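The three --no-deps installs added above pin torchvision and a prerelease transformers wheel without letting the resolver touch the rest of the environment. A minimal sanity check, not part of the commit, that could be run inside the built image:

# Hedged sanity check (not part of the commit): confirm the --no-deps installs
# above import cleanly and report their versions inside the image.
import importlib.metadata as md

import torchvision  # noqa: F401  (fails loudly if the wheel is missing)
import transformers

print("torchvision:", md.version("torchvision"))
print("transformers:", transformers.__version__)  # expected: 4.51.0.dev0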
@@ -209,6 +209,7 @@ try:
     from text_generation_server.models.transformers_flash_vlm import (
         TransformersFlashVlmCausalLM,
         TransformersGemma3VlmCausalLM,
+        TransformersLlama4VlmCausalLM,
     )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
@@ -1030,7 +1031,7 @@ def get_model(
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model
 
-            return TransformersFlashVlmCausalLM.fallback(
+            return TransformersLlama4VlmCausalLM.fallback(
                 model_id,
                 Llama4Model,
                 revision,
@@ -1038,10 +1039,8 @@ def get_model(
                 speculator=speculator,
                 dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
-                # how to load from preprocessor_config.json
                 processor_kwargs={
                     "use_fast": True,
-                    "max_patches": 15,
                     "size": {"height": 336, "width": 336},
                 },
             )
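For reference, the trimmed processor_kwargs above amount to the processor call sketched below, assuming fallback() forwards them to AutoProcessor.from_pretrained; the checkpoint id is a placeholder, not taken from this diff.

# Hedged sketch: the processor_kwargs from the hunk above, as they would look
# if forwarded to AutoProcessor.from_pretrained. The checkpoint id is a
# placeholder and the exact plumbing inside fallback() is not shown here.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder checkpoint
    use_fast=True,
    size={"height": 336, "width": 336},
)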
@@ -202,7 +202,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
 
         attn_implementation = {
             "text_config": "tgi",
-            "vision_config": "eager",
+            "vision_config": "sdpa",
         }
 
         model = model_class.from_pretrained(
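Recent transformers releases accept a mapping for attn_implementation so each sub-config can use a different attention backend, which is what the hunk above relies on: the text stack keeps TGI's custom "tgi" attention while the vision tower moves from eager to SDPA. A standalone sketch of the same pattern, with "sdpa" substituted for the "tgi" implementation (which only exists once TGI registers it) and a placeholder checkpoint:

# Hedged sketch of a per-sub-config attn_implementation mapping, mirroring the
# hunk above. "sdpa" stands in for TGI's custom "tgi" implementation, and the
# checkpoint id is a placeholder.
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation={
        "text_config": "sdpa",    # the diff uses "tgi" here
        "vision_config": "sdpa",  # the change above: "eager" -> "sdpa"
    },
)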
@@ -395,6 +395,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             image_grid_thw=image_grid_thw,
             attention_mask=inputs.get("attention_mask", None),
             use_sdpa=inputs.get("use_sdpa", False),
+            cache_position=inputs.get("cache_position", None)
         ).logits
 
         logits = self.post_process_outputs(logits, lm_head_indices)
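Because the new keyword is read with a None default, subclasses whose pre_process_inputs never sets "cache_position" keep passing cache_position=None and are untouched by this change. A toy illustration:

# Toy illustration: an inputs dict without "cache_position" resolves to None,
# so existing VLM subclasses behave exactly as before this commit.
inputs = {"use_sdpa": True}  # illustrative stand-in for pre_process_inputs output
print(inputs.get("cache_position", None))  # -> None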
@@ -555,3 +556,12 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["use_sdpa"] = True
 
         return inputs
+
+
+class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
+    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+        inputs = super().pre_process_inputs(
+            input_ids, position_ids, cu_seqlen_prefill
+        )
+        inputs["cache_position"] = position_ids
+        return inputs
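The new TransformersLlama4VlmCausalLM only overrides input preparation: it copies the flattened batch's position_ids into "cache_position", so the Llama4 forward call above receives an explicit cache position. A self-contained sketch of what the override produces, with toy tensors:

# Hedged sketch of the new override's effect: position_ids from the flattened
# batch are reused as cache_position. Tensors are toy data; the real inputs
# dict comes from super().pre_process_inputs().
import torch

position_ids = torch.tensor([0, 1, 2, 0, 1])  # two packed sequences (toy data)
inputs = {"position_ids": position_ids}       # stand-in for the parent's output
inputs["cache_position"] = position_ids       # the one line this subclass adds
print(inputs["cache_position"])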