Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 23:42:06 +00:00
Add support for other vlm

commit ac6fc70c75
parent d8d09f9c7b
server/text_generation_server/models/__init__.py

@@ -208,6 +208,8 @@ try:
     )
     from text_generation_server.models.transformers_flash_vlm import (
         TransformersFlashVlmCausalLM,
+        TransformersQwen2VlmCausalLM,
+        TransformersGemma3VlmCausalLM,
     )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
@@ -1161,7 +1163,6 @@ def get_model(
             )
     elif model_type == GEMMA3:
         if FLASH_ATTENTION:
-            # TODO: Use VlmCausalLM when image support is added.
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Gemma3ForConditionalGeneration,
@@ -1179,9 +1180,11 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         elif FLASH_TRANSFORMERS_BACKEND:
-            return TransformersFlashVlmCausalLM.fallback(
+            from transformers import Gemma3ForConditionalGeneration as Gemma3Model
+
+            return TransformersGemma3VlmCausalLM.fallback(
                 model_id,
-                # AutoModelForConditionalGeneration,
+                Gemma3Model,
                 revision,
                 quantize=quantize,
                 speculator=speculator,
@@ -1490,15 +1493,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
-        return TransformersFlashVlmCausalLM.fallback(
-            model_id,
-            # AutoModelForConditionalGeneration,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=torch.bfloat16,
-            trust_remote_code=trust_remote_code,
-        )
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Qwen2VLForConditionalGeneration,
@@ -1511,7 +1506,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
+
+            return TransformersQwen2VlmCausalLM.fallback(
+                model_id,
+                Qwen2VLModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+            )
     if model_type == QWEN2_5_VL:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Qwen2_5VLForConditionalGeneration,
@@ -1526,6 +1534,19 @@ def get_model(
                 config_class=Qwen2_5_VLConfig,
                 processor_class=Qwen2_5_VLProcessor,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+
+            return TransformersQwen2VlmCausalLM.fallback(
+                model_id,
+                Qwen2VLModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                config_class=Qwen2_5_VLConfig,
+                processor_class=Qwen2_5_VLProcessor,
+            )
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
             return MllamaCausalLM(
@@ -1540,6 +1561,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import MllamaForConditionalGeneration as MllamaModel
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                MllamaModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                batch_class=MllamaCausalLMBatch,
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
@@ -1558,6 +1592,19 @@ def get_model(
                 # VRAM usage.
                 processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import Idefics2ForConditionalGeneration as Idefics2Model
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                Idefics2Model,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == IDEFICS3:
@@ -1576,20 +1623,23 @@ def get_model(
                 # VRAM usage.
                 processor_kwargs={"size": {"longest_edge": 1456}},
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import Idefics3ForConditionalGeneration as Idefics3Model
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                Idefics3Model,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                processor_kwargs={"size": {"longest_edge": 1456}},
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        # return TransformersFlashVlmCausalLM.fallback(
-        #     model_id,
-        #     # AutoModelForConditionalGeneration,
-        #     revision,
-        #     quantize=quantize,
-        #     speculator=speculator,
-        #     dtype=torch.bfloat16,
-        #     trust_remote_code=trust_remote_code,
-        #     batch_class=PaliGemmaBatch,
-        # )
-        # if FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1604,9 +1654,21 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
                 batch_class=PaliGemmaBatch,
             )
-        # else:
-        #     raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                PaliGemmaModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                batch_class=PaliGemmaBatch,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

     if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return VlmCausalLM(
@@ -1619,6 +1681,18 @@ def get_model(
                 kv_cache_dtype=kv_cache_dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            from transformers import LlavaNextForConditionalGeneration as LlavaNextModel
+
+            return TransformersFlashVlmCausalLM.fallback(
+                model_id,
+                LlavaNextModel,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

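The same three-way dispatch is repeated above for every VLM type this commit touches: prefer TGI's native flash-attention runtime, otherwise fall back to the flash Transformers backend, otherwise raise. A hedged, self-contained sketch of that pattern follows; the flags and classes are stubbed locally so it can run in isolation, and SomeVlmForConditionalGeneration / get_some_vlm are hypothetical names, not part of the repo.

# Stubs standing in for names defined in text_generation_server.models.__init__.
FLASH_ATTENTION = False
FLASH_TRANSFORMERS_BACKEND = True
FLASH_ATT_ERROR_MESSAGE = "{} requires flash attention"


class VlmCausalLM:  # stub for the native flash-attention runtime
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class TransformersFlashVlmCausalLM:  # stub for the Transformers-backend runtime
    @classmethod
    def fallback(cls, model_id, model_class, revision=None, **kwargs):
        instance = cls()
        instance.model_class = model_class
        return instance


class SomeVlmForConditionalGeneration:  # hypothetical model class
    pass


def get_some_vlm(model_id, revision=None, **kwargs):
    if FLASH_ATTENTION:
        return VlmCausalLM(
            model_id=model_id,
            model_class=SomeVlmForConditionalGeneration,
            **kwargs,
        )
    elif FLASH_TRANSFORMERS_BACKEND:
        # The new fallback path: pass the stock Transformers class explicitly.
        return TransformersFlashVlmCausalLM.fallback(
            model_id, SomeVlmForConditionalGeneration, revision, **kwargs
        )
    raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("SomeVlm"))


print(type(get_some_vlm("org/some-vlm")).__name__)  # TransformersFlashVlmCausalLM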
server/text_generation_server/models/transformers_flash_vlm.py

@@ -17,6 +17,15 @@ import torch.nn.functional as F

 tracer = trace.get_tracer(__name__)

+# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
+# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
+# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
+# to internal constraints it was not (yet?) possible to circumvent
+REPLICATED_ATTENTION_MODELS = [
+    "olmo2",
+    "phi3",
+]
+

 def tgi_flash_attention_forward(
     module,
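The comment moved to the top of the file explains why q/k/v stay replicated for these architectures. A hedged sketch of how a loader might consult such a list when sizing per-rank attention heads; heads_per_rank is a hypothetical helper, not the repo's actual sharding code.

REPLICATED_ATTENTION_MODELS = ["olmo2", "phi3"]  # mirrors the list added above


def heads_per_rank(model_type: str, num_heads: int, world_size: int) -> int:
    # When the TP plan replicates q/k/v, every rank sees the full states, so the
    # KV cache must be sized for all heads instead of num_heads // world_size.
    if model_type in REPLICATED_ATTENTION_MODELS:
        return num_heads
    return num_heads // world_size


assert heads_per_rank("phi3", 32, 4) == 32
assert heads_per_rank("llama", 32, 4) == 8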
@@ -35,17 +44,13 @@ def tgi_flash_attention_forward(
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
+    use_sdpa: Optional[bool] = False,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
-    from loguru import logger
-
-    logger.info("Using TGI Flash Attention")
-    # from pdb import set_trace; set_trace()
     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
-    # from pdb import set_trace; set_trace()

     # Take care of updating the cache in-place
     kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
@@ -55,10 +60,8 @@ def tgi_flash_attention_forward(
     sliding_window = -1 if sliding_window is None else sliding_window
     # if module.layer_idx == 0:
     #     from pdb import set_trace; set_trace()
-
     if cu_seqlen_prefill is not None:
-        attention_mask = None
-        if attention_mask is None:
+        if not use_sdpa:
             attn_output = attention(
                 query=query_states,
                 key=key_states,
@@ -72,9 +75,6 @@ def tgi_flash_attention_forward(
                 softcap=softcap,
             )
         else:
-            from loguru import logger
-
-            logger.info("uSING FLASH ATTENTION")
             lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
             max_length = max(lengths)
             attention_mask = attention_mask[:, :, :, :max_length]
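The two hunks above swap the debug logging for a use_sdpa switch: prefill goes through the flash attention kernel unless the caller attached a dense mask, in which case the SDPA branch (which can honour arbitrary masks, as Gemma3 needs) runs instead. A small self-contained sketch of that routing, using plain PyTorch SDPA for both branches; prefill_attention and the tensor shapes are illustrative, not the repo's API.

import torch
import torch.nn.functional as F


def prefill_attention(query, key, value, attention_mask=None, use_sdpa=False):
    # query/key/value: [batch, num_heads, seq_len, head_dim] in this toy example.
    if not use_sdpa:
        # Stand-in for the flash/varlen kernel branch: plain causal attention.
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    # SDPA branch: an explicit mask can make image tokens mutually visible.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)


q = k = v = torch.randn(1, 2, 4, 8)
mask = torch.zeros(1, 1, 4, 4)  # additive mask; zeros mean every position is visible
out = prefill_attention(q, k, v, attention_mask=mask, use_sdpa=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])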
@@ -135,34 +135,62 @@ def tgi_flash_attention_forward(


 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

+# Siglip
 transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = (
     transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"]
 )
-# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"]
-# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"]
-transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = (
-    tgi_flash_attention_forward
-)
+# Qwen2VL
 transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
     "tgi"
 ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
     "eager"
 ]
+# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
+# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward

-# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
-# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
-# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
-# to internal constraints it was not (yet?) possible to circumvent
-REPLICATED_ATTENTION_MODELS = [
-    "olmo2",
-    "phi3",
-]
+# Idefics2
+transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
+    "tgi"
+] = transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
+    "eager"
+]
+transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
+    "tgi"
+] = transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
+    "eager"
+]
+
+# Idefics3
+transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
+    "tgi"
+] = transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
+    "eager"
+]
+
+# Clip
+transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["tgi"] = (
+    transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["sdpa"]
+)
+
+# Mllama
+transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["tgi"] = (
+    transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["eager"]
+)
+# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
+# transformers.models.mllama.modeling_mllama.MLLAMA_TEXT_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
+# transformers.models.mllama.modeling_mllama.MLLAMA_CROSS_ATTENTION_CLASSES["tgi"] = tgi_cross_attention_forward
+
+# TODO: implement
+# tgi_cross_attention_forward


 class TransformersFlashVlmCausalLM(VlmCausalLM):
     def __init__(
         self,
         model_id: str,
+        model_class,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
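All of the assignments above follow one pattern: each per-model *_ATTENTION_CLASSES registry in Transformers gains a "tgi" entry that either aliases an existing implementation ("sdpa" or "eager") or points at tgi_flash_attention_forward, so instantiating a model with the "tgi" attention implementation routes its layers through the chosen function. A hedged toy illustration of that registry pattern; the dictionary and class below are stand-ins, not the Transformers internals.

# Toy registry, analogous to ALL_ATTENTION_FUNCTIONS / *_ATTENTION_CLASSES.
ATTENTION_IMPLEMENTATIONS = {
    "eager": lambda q, k, v: "eager result",
    "sdpa": lambda q, k, v: "sdpa result",
}


def tgi_attention(q, k, v):
    return "tgi flash result"


# Register a "tgi" entry; for modules TGI does not accelerate, alias an existing
# implementation, exactly like the Siglip/CLIP lines above alias "sdpa".
ATTENTION_IMPLEMENTATIONS["tgi"] = tgi_attention


class ToyAttentionModule:
    def __init__(self, attn_implementation="tgi"):
        self.attn_fn = ATTENTION_IMPLEMENTATIONS[attn_implementation]

    def forward(self, q, k, v):
        return self.attn_fn(q, k, v)


print(ToyAttentionModule("tgi").forward(None, None, None))  # tgi flash result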
@@ -175,10 +203,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         kv_cache_dtype: Optional[torch.dtype] = None,
         batch_class=VlmCausalLMBatch,
     ):
-        # # from pdb import set_trace; set_trace()
-        self.batch_class = VlmCausalLMBatch
+        self.batch_class = batch_class
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
+        self.dtype = dtype

         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
@@ -204,16 +232,15 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):

         if processor_kwargs is None:
             processor_kwargs = {}
-        # processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}
         self.processor = processor_class.from_pretrained(
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
             **processor_kwargs,
         )
-        from transformers import Qwen2VLForConditionalGeneration

-        model = Qwen2VLForConditionalGeneration.from_pretrained(
+        model = model_class.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
@@ -318,6 +345,100 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):

         torch.distributed.barrier(group=self.process_group)

+    def get_position_ids(self, input_ids, image_grid_thw, position_ids):
+        return position_ids
+
+    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+        return {
+            "input_ids": input_ids.unsqueeze(0),
+            "position_ids": position_ids.unsqueeze(0),
+        }
+
+    def post_process_outputs(self, logits, lm_head_indices):
+        return logits.squeeze(dim=0)
+
+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        batch_class: Optional[type] = VlmCausalLMBatch,
+        processor_kwargs: Optional[dict] = None,
+    ):
+        return cls(
+            model_id=model_id,
+            model_class=model_class,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+            batch_class=batch_class,
+            processor_kwargs=processor_kwargs,
+        )
+
+    def _model_forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[KVCache],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        lm_head_indices: Optional[torch.Tensor],
+        prefill_cache_indices=None,  # not used, but passed to match original signature
+        adapter_data=None,  # not supported, but passed to match original signature
+        pixel_values: torch.FloatTensor = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        pixel_attention_mask=None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
+        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
+
+        inputs = self.pre_process_inputs(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+        )
+
+        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
+        logits = self.model.original_forward(
+            input_ids=inputs["input_ids"],
+            position_ids=inputs["position_ids"],
+            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
+            use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            logits_to_keep=logits_to_keep,
+            return_dict=True,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            seqlen=seqlen,
+            max_s=max_s,
+            kv_head_mapping=self.kv_head_mapping,
+            kv_scales=self.kv_scales,
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_sizes=image_sizes,
+            image_grid_thw=image_grid_thw,
+            attention_mask=inputs.get("attention_mask", None),
+            use_sdpa=inputs.get("use_sdpa", False),
+        ).logits
+
+        logits = self.post_process_outputs(logits, lm_head_indices)
+
+        return logits, None
+
+
+class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
     def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor):
         if image_grid_thw is None:
             return (
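The base class now funnels every request through three overridable hooks: pre_process_inputs reshapes ids and positions (and may attach an attention mask), the patched original_forward runs the Transformers model, and post_process_outputs strips the batch dimension. A minimal self-contained sketch of that template-method layout with toy classes; it only mimics the tensor reshaping, not the real model call.

import torch


class ToyFlashVlm:
    # Mirrors the hook structure above: subclasses only override pre/post hooks.
    def pre_process_inputs(self, input_ids, position_ids):
        return {
            "input_ids": input_ids.unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }

    def post_process_outputs(self, logits, lm_head_indices):
        return logits.squeeze(dim=0)

    def forward(self, input_ids, position_ids, lm_head_indices=None):
        inputs = self.pre_process_inputs(input_ids, position_ids)
        logits = inputs["input_ids"].float()  # stand-in for the real model call
        return self.post_process_outputs(logits, lm_head_indices)


class ToyQwen2Vlm(ToyFlashVlm):
    def post_process_outputs(self, logits, lm_head_indices):
        # Keep only the requested rows, as TransformersQwen2VlmCausalLM does in the next hunk.
        return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)


ids = torch.arange(6)
pos = torch.arange(6)
print(ToyFlashVlm().forward(ids, pos).shape)                     # torch.Size([6])
print(ToyQwen2Vlm().forward(ids, pos, torch.tensor([5])).shape)  # torch.Size([1, 1])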
@@ -391,77 +512,82 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         )
         return position_ids

-    @classmethod
-    def fallback(
-        cls,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        batch_class: Optional[type] = VlmCausalLMBatch,
-    ):
-        return cls(
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-            batch_class=batch_class,
-        )
-
-    def _model_forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[KVCache],
-        block_tables: torch.Tensor,
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        max_s: int,
-        lm_head_indices: Optional[torch.Tensor],
-        prefill_cache_indices=None,  # not used, but passed to match original signature
-        adapter_data=None,  # not supported, but passed to match original signature
-        pixel_values: torch.FloatTensor = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
-    ):
-        # from pdb import set_trace; set_trace()
-        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
-        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
-
-        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
-        logits = (
-            self.model.original_forward(
-                input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
-                position_ids=position_ids.transpose(0, 1).unsqueeze(
-                    1
-                ),  # expand dim to fit Transformers
-                past_key_values=None,  # we use self.kv_cache instead of transformers cache object
-                use_cache=False,  # we use self.kv_cache instead of transformers cache object
-                logits_to_keep=logits_to_keep,
-                return_dict=True,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                seqlen=seqlen,
-                max_s=max_s,
-                kv_head_mapping=self.kv_head_mapping,
-                kv_scales=self.kv_scales,
-                pixel_values=pixel_values,
-                pixel_attention_mask=pixel_attention_mask,
-                image_sizes=image_sizes,
-                image_grid_thw=image_grid_thw,
-            )
-            .logits.squeeze(dim=0)[lm_head_indices]
-            .unsqueeze(0)
-        )
-
-        # from pdb import set_trace; set_trace()
-        return logits, None
+    def post_process_outputs(self, logits, lm_head_indices):
+        return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)
+
+    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+        input_ids = input_ids.unsqueeze(0)
+        position_ids = position_ids.transpose(0, 1).unsqueeze(1)
+        return {"input_ids": input_ids, "position_ids": position_ids}
+
+
+class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
+    def get_attention_mask(self, input_ids, cu_seqlen_prefill):
+        device = input_ids.device
+        dtype = self.dtype
+        min_dtype = torch.finfo(dtype).min
+
+        lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
+        batch_size = len(lengths)
+
+        sequence_length = max(lengths)
+        target_length = sequence_length
+        # Create the padding mask from the computed lengths.
+        # pad_mask: [batch, sequence_length] where True indicates valid tokens.
+        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
+        lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
+        pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]
+
+        # Build the base causal mask (for non-image tokens):
+        causal_mask = torch.tril(
+            torch.ones(
+                (sequence_length, sequence_length), dtype=torch.bool, device=device
+            )
+        )
+        base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
+            1
+        )  # [batch, sequence_length, sequence_length]
+        base_mask = base_mask & causal_mask.unsqueeze(0)  # apply causal constraint
+
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
+        image_token_mask = torch.nn.utils.rnn.pad_sequence(
+            torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
+        )
+        bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
+            1
+        )
+
+        # Combine the causal base mask and the bidirectional mask.
+        combined_mask = torch.logical_or(
+            base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
+        ).to(device)
+        # combined_mask now has shape [batch, 1, sequence_length, sequence_length]
+
+        full_attention_mask = torch.zeros(
+            (batch_size, 1, sequence_length, target_length),
+            device=device,
+            dtype=torch.bool,
+        )
+        full_attention_mask[:, :, :, :sequence_length] = combined_mask
+
+        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
+
+        return final_attention_mask
+
+    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+        inputs = {
+            "input_ids": input_ids.unsqueeze(0),
+            "position_ids": position_ids.unsqueeze(0),
+        }
+
+        if cu_seqlen_prefill is not None:
+            attention_mask = self.get_attention_mask(
+                input_ids.squeeze(0), cu_seqlen_prefill
+            )
+            inputs["attention_mask"] = attention_mask
+            inputs["use_sdpa"] = True
+
+        return inputs
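get_attention_mask above mixes two visibilities: text tokens attend causally, while Gemma3's image tokens also attend bidirectionally to one another before everything is padded per batch and converted to an additive float mask. A small self-contained illustration of that combination on one toy sequence; token id 99 is a stand-in for config.image_token_index.

import torch

input_ids = torch.tensor([5, 99, 99, 7])  # one sequence: text, image, image, text
seq_len = input_ids.shape[0]

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
image = input_ids == 99
bidirectional = image.unsqueeze(1) & image.unsqueeze(0)

combined = causal | bidirectional  # image tokens see each other in both directions
print(combined.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)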