mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00

Add comments for support of models

parent 50ffe00a1a
commit b5bac0dd2d
@@ -208,13 +208,13 @@ try:
    )
    from text_generation_server.models.transformers_flash_vlm import (
        TransformersFlashVlmCausalLM,
        TransformersQwen2VlmCausalLM,
        TransformersGemma3VlmCausalLM,
    )
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
    FLASH_TRANSFORMERS_BACKEND = False

# TODO: remove this, it's a temporary for testing the FLASH_TRANSFORMERS_BACKEND
FLASH_ATTENTION = False
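
Editorial context for the hunk above (not part of the commit): the two module-level flags decide which branch get_model() takes, and the commit temporarily forces FLASH_ATTENTION off so the Transformers flash backend is exercised. A minimal, self-contained sketch of that import-guard and flag-gating pattern, with transformers standing in for the heavier text_generation_server imports the real code guards:

# Editorial sketch, not from the commit: import guard plus flag gating.
FLASH_TRANSFORMERS_BACKEND = True

try:
    from transformers import AutoModelForCausalLM  # noqa: F401
except ImportError as e:
    # A missing or broken optional dependency only disables the backend;
    # it does not crash the server at import time.
    print(f"Could not import Flash Transformers Backend: {e}")
    FLASH_TRANSFORMERS_BACKEND = False

# Forcing FLASH_ATTENTION off, as the commit does temporarily, pushes every
# model through the "elif FLASH_TRANSFORMERS_BACKEND:" branches below.
FLASH_ATTENTION = False


def select_backend(model_type: str) -> str:
    # Same branch order as get_model(): native flash kernels first,
    # then the Transformers flash backend, otherwise unsupported.
    if FLASH_ATTENTION:
        return "native-flash"
    if FLASH_TRANSFORMERS_BACKEND:
        return "transformers-flash"
    raise NotImplementedError(f"{model_type} requires flash attention")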
@@ -1506,18 +1506,21 @@ def get_model(
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif FLASH_TRANSFORMERS_BACKEND:
            from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
        # TODO: Uncomment when transformers is refactored
        # elif FLASH_TRANSFORMERS_BACKEND:
        #     from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel

            return TransformersQwen2VlmCausalLM.fallback(
                model_id,
                Qwen2VLModel,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
            )
        #     return TransformersQwen2VlmCausalLM.fallback(
        #         model_id,
        #         Qwen2VLModel,
        #         revision,
        #         quantize=quantize,
        #         speculator=speculator,
        #         dtype=torch.bfloat16,
        #         trust_remote_code=trust_remote_code,
        #     )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
    if model_type == QWEN2_5_VL:
        if FLASH_ATTENTION:
            return VlmCausalLM(
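
Editorial sketch (not part of the commit): the .fallback(...) classmethod used above wraps a stock transformers model class so it can stand in for a native TGI model. The class and signature below are illustrative assumptions, not the actual TransformersQwen2VlmCausalLM API; quantization and speculator handling are omitted:

# Hypothetical sketch of a ".fallback(...)"-style constructor.
import torch


class SketchVlmFallback:
    def __init__(self, model, processor):
        self.model = model
        self.processor = processor

    @classmethod
    def fallback(
        cls,
        model_id: str,
        model_class,
        revision=None,
        *,
        dtype=torch.bfloat16,
        trust_remote_code: bool = False,
        **_ignored,  # quantize/speculator handling omitted in this sketch
    ):
        from transformers import AutoProcessor

        # Load the stock transformers implementation and its processor,
        # then wrap them so the serving code can treat this like a TGI model.
        model = model_class.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
        processor = AutoProcessor.from_pretrained(model_id, revision=revision)
        return cls(model, processor)

A call would then mirror the hunk, e.g. SketchVlmFallback.fallback(model_id, Qwen2VLForConditionalGeneration, revision, dtype=torch.bfloat16, trust_remote_code=trust_remote_code).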
@@ -1534,19 +1537,21 @@ def get_model(
                config_class=Qwen2_5_VLConfig,
                processor_class=Qwen2_5_VLProcessor,
            )
        elif FLASH_TRANSFORMERS_BACKEND:

            return TransformersQwen2VlmCausalLM.fallback(
                model_id,
                Qwen2VLModel,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                config_class=Qwen2_5_VLConfig,
                processor_class=Qwen2_5_VLProcessor,
            )
        # TODO: Uncomment when transformers is refactored
        # elif FLASH_TRANSFORMERS_BACKEND:
        #     return TransformersQwen2VlmCausalLM.fallback(
        #         model_id,
        #         Qwen2VLModel,
        #         revision,
        #         quantize=quantize,
        #         speculator=speculator,
        #         dtype=torch.bfloat16,
        #         trust_remote_code=trust_remote_code,
        #         config_class=Qwen2_5_VLConfig,
        #         processor_class=Qwen2_5_VLProcessor,
        #     )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
@@ -1561,19 +1566,20 @@ def get_model(
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif FLASH_TRANSFORMERS_BACKEND:
            from transformers import MllamaForConditionalGeneration as MllamaModel
        # TODO: Uncomment when transformers is refactored and cross attn is added
        # elif FLASH_TRANSFORMERS_BACKEND:
        #     from transformers import MllamaForConditionalGeneration as MllamaModel

            return TransformersFlashVlmCausalLM.fallback(
                model_id,
                MllamaModel,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                batch_class=MllamaCausalLMBatch,
            )
        #     return TransformersFlashVlmCausalLM.fallback(
        #         model_id,
        #         MllamaModel,
        #         revision,
        #         quantize=quantize,
        #         speculator=speculator,
        #         dtype=torch.bfloat16,
        #         trust_remote_code=trust_remote_code,
        #         batch_class=MllamaCausalLMBatch,
        #     )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
    if model_type == IDEFICS2:
@@ -27,6 +27,12 @@ REPLICATED_ATTENTION_MODELS = [
]


# # Qwen2VL
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
#     "tgi"
# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
#     "eager"
# ]
def tgi_flash_attention_forward(
    module,
    query_states: torch.Tensor,
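
Editorial note on the hunk above (not part of the commit): the commented-out assignment aliased the "tgi" key to the eager class in the older per-model QWEN2_VL_VISION_ATTENTION_CLASSES dict, while tgi_flash_attention_forward is what gets registered with the attention registry that recent transformers versions consult when a model is loaded with attn_implementation="tgi". A hedged sketch of that registration pattern follows; the placeholder body defers to PyTorch SDPA instead of TGI's paged/flash kernels, and the registry import assumes a recent transformers release:

# Hedged sketch of registering a custom attention implementation with
# transformers; the real tgi_flash_attention_forward calls TGI's kernels.
import torch
import torch.nn.functional as F
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def sketch_tgi_attention_forward(
    module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask=None,
    **kwargs,
):
    # Tensors follow the (batch, heads, seq, head_dim) convention; KV-cache
    # paging and all TGI-specific handling are omitted in this placeholder.
    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=attention_mask
    )
    # transformers attention functions return (output, attention_weights).
    return attn_output.transpose(1, 2).contiguous(), None


# Models loaded with attn_implementation="tgi" would then dispatch here.
ALL_ATTENTION_FUNCTIONS["tgi"] = sketch_tgi_attention_forward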