Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-04-28 05:22:07 +00:00)

commit b5bac0dd2d
parent 50ffe00a1a

    Add comments for support of models
@@ -208,13 +208,13 @@ try:
     )
     from text_generation_server.models.transformers_flash_vlm import (
         TransformersFlashVlmCausalLM,
-        TransformersQwen2VlmCausalLM,
         TransformersGemma3VlmCausalLM,
     )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
     FLASH_TRANSFORMERS_BACKEND = False
 
+# TODO: remove this, it's a temporary for testing the FLASH_TRANSFORMERS_BACKEND
 FLASH_ATTENTION = False
 
 
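Reviewer note: the hunk above uses the usual capability-flag pattern — the optional backend is imported inside a try block, and an ImportError just flips FLASH_TRANSFORMERS_BACKEND off — while the FLASH_ATTENTION = False override (now annotated with a TODO) forces model loading away from the native flash path so the transformers backend gets exercised. A minimal, self-contained sketch of that pattern, using plain logging instead of TGI's log_master helper:

```python
import logging

logger = logging.getLogger(__name__)

FLASH_TRANSFORMERS_BACKEND = True
try:
    # Optional dependency: only importable when the flash transformers backend
    # (and its torch/transformers requirements) is installed.
    from text_generation_server.models.transformers_flash_vlm import (
        TransformersFlashVlmCausalLM,
    )
except ImportError as e:
    # Degrade gracefully: record why and disable the capability flag so the
    # dispatch code never takes this branch.
    logger.warning("Could not import Flash Transformers Backend: %s", e)
    FLASH_TRANSFORMERS_BACKEND = False

# Temporary override mirroring the TODO in the diff: with flash attention
# forced off, model loading falls through to the transformers backend.
FLASH_ATTENTION = False
```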
@@ -1506,18 +1506,21 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
 
-            return TransformersQwen2VlmCausalLM.fallback(
-                model_id,
-                Qwen2VLModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-            )
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
     if model_type == QWEN2_5_VL:
         if FLASH_ATTENTION:
             return VlmCausalLM(
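Reviewer note: the Qwen2_VL branch follows the same three-way dispatch as the other multimodal models — native flash kernels when available, the transformers flash backend as a fallback (commented out here until transformers is refactored), and an explicit error otherwise. A hedged sketch of that control flow with placeholder classes (not TGI's actual implementations; the error-message text stands in for FLASH_ATT_ERROR_MESSAGE):

```python
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."


class VlmCausalLM:
    """Placeholder for the native flash-attention serving path."""

    def __init__(self, model_id: str):
        self.model_id = model_id


class TransformersQwen2VlmCausalLM:
    """Placeholder for the transformers-backend fallback path."""

    @classmethod
    def fallback(cls, model_id: str, **kwargs):
        instance = cls()
        instance.model_id = model_id
        instance.options = kwargs
        return instance


def load_qwen2_vl(model_id, flash_attention=False, flash_transformers_backend=False):
    if flash_attention:
        return VlmCausalLM(model_id)
    # elif flash_transformers_backend:  # disabled by this commit, see TODO above
    #     return TransformersQwen2VlmCausalLM.fallback(model_id, dtype="bfloat16")
    else:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
```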
@@ -1534,19 +1537,21 @@ def get_model(
                 config_class=Qwen2_5_VLConfig,
                 processor_class=Qwen2_5_VLProcessor,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            return TransformersQwen2VlmCausalLM.fallback(
-                model_id,
-                Qwen2VLModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                config_class=Qwen2_5_VLConfig,
-                processor_class=Qwen2_5_VLProcessor,
-            )
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         config_class=Qwen2_5_VLConfig,
+        #         processor_class=Qwen2_5_VLProcessor,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
             return MllamaCausalLM(
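Reviewer note: for Qwen2.5-VL the (commented) fallback reuses the Qwen2-VL model class and only swaps in config_class=Qwen2_5_VLConfig and processor_class=Qwen2_5_VLProcessor via keyword arguments. A sketch of what such a fallback() constructor might look like, illustrating the keyword pass-through rather than TGI's actual signature (all names below are placeholders):

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Type


@dataclass
class FallbackSpec:
    """Everything the wrapper needs to build the underlying transformers model."""

    model_id: str
    model_class: Type[Any]
    revision: Optional[str] = None
    options: dict = field(default_factory=dict)  # quantize, speculator, dtype, ...


class TransformersVlmWrapper:
    def __init__(self, spec: FallbackSpec):
        self.spec = spec

    @classmethod
    def fallback(cls, model_id, model_class, revision=None, **kwargs):
        # config_class / processor_class travel in kwargs, so one wrapper can
        # serve Qwen2-VL and Qwen2.5-VL with different configs and processors.
        return cls(FallbackSpec(model_id, model_class, revision, dict(kwargs)))


# Usage shaped like the commented call in the diff (values are illustrative):
spec_like = TransformersVlmWrapper.fallback(
    "some-org/some-qwen2_5-vl-model",  # illustrative model id
    model_class=object,                # stands in for Qwen2VLForConditionalGeneration
    revision=None,
    dtype="bfloat16",
    config_class="Qwen2_5_VLConfig",
    processor_class="Qwen2_5_VLProcessor",
)
```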
@@ -1561,19 +1566,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            from transformers import MllamaForConditionalGeneration as MllamaModel
+        # TODO: Uncomment when transformers is refactored and cross attn is added
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import MllamaForConditionalGeneration as MllamaModel
 
-            return TransformersFlashVlmCausalLM.fallback(
-                model_id,
-                MllamaModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                batch_class=MllamaCausalLMBatch,
-            )
+        #     return TransformersFlashVlmCausalLM.fallback(
+        #         model_id,
+        #         MllamaModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         batch_class=MllamaCausalLMBatch,
+        #     )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
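Reviewer note: the Mllama branch differs in two ways visible here — its TODO also waits on cross-attention support, and the (commented) fallback passes batch_class=MllamaCausalLMBatch, presumably so batch construction can carry model-specific state such as cross-attention inputs. A small sketch of that swappable-batch-class idea, with illustrative names rather than TGI's classes:

```python
class DefaultVlmBatch:
    """Plain batch: just the request payloads."""

    def __init__(self, requests):
        self.requests = requests


class CrossAttentionBatch(DefaultVlmBatch):
    """Batch that also tracks cross-attention state produced by the vision tower."""

    def __init__(self, requests):
        super().__init__(requests)
        self.cross_attention_states = None


class VlmServingWrapper:
    def __init__(self, batch_class=DefaultVlmBatch):
        # Injecting the batch class keeps the wrapper model-agnostic.
        self.batch_class = batch_class

    def make_batch(self, requests):
        return self.batch_class(requests)


wrapper = VlmServingWrapper(batch_class=CrossAttentionBatch)
batch = wrapper.make_batch(["describe the image", "caption this"])
assert isinstance(batch, CrossAttentionBatch)
```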
@@ -27,6 +27,12 @@ REPLICATED_ATTENTION_MODELS = [
 ]
 
 
+# # Qwen2VL
+# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+#     "tgi"
+# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+#     "eager"
+# ]
 def tgi_flash_attention_forward(
     module,
     query_states: torch.Tensor,
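Reviewer note: the commented-out block added above would register a "tgi" entry in Qwen2-VL's vision attention class mapping by aliasing the existing "eager" implementation; tgi_flash_attention_forward below is the custom attention entry point itself. The registry-aliasing trick, sketched generically (the dictionary and keys here are illustrative, not the real transformers attributes):

```python
# A stand-in registry mapping attention-implementation names to callables.
ATTENTION_CLASSES = {
    "eager": lambda q, k, v: ("eager", q, k, v),
    "sdpa": lambda q, k, v: ("sdpa", q, k, v),
}

# Register "tgi" as an alias of "eager": model code that looks up the configured
# implementation name keeps working, and the alias can later be swapped for a
# real flash-attention forward without touching the model code.
ATTENTION_CLASSES["tgi"] = ATTENTION_CLASSES["eager"]

assert ATTENTION_CLASSES["tgi"] is ATTENTION_CLASSES["eager"]
```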