Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-11 20:45:40 +00:00
parent 7533b993d5
commit 50ecfc625a
4 changed files with 24 additions and 22 deletions

View File

@@ -139,10 +139,21 @@ except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
+VLM_BATCH_TYPES = set()
 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -848,6 +859,11 @@ def get_model(
     from text_generation_server.models.custom_modeling.llava_next import (
         LlavaNextForConditionalGeneration,
     )
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
+    VLM_BATCH_TYPES.add(VlmCausalLMBatch)
     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
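Taken together, the two hunks above make models/__init__.py the single owner of VLM_BATCH_TYPES: the set defaults to empty when the Flash Attention imports fail, is filled with the flash VLM batch classes when they succeed, and is extended later inside get_model() next to the LlavaNext import. A condensed sketch of the resulting module-level logic, as a simplification rather than the full file (PaliGemmaBatch and FlashMllamaCausalLMBatch are assumed to be imported earlier in the same guarded block that sets FLASH_ATTENTION):

VLM_BATCH_TYPES = set()  # safe default when Flash Attention models cannot be imported

if FLASH_ATTENTION:
    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        FlashVlmCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }

# Later, inside get_model(), the non-flash path extends the same set:
#     from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
#     VLM_BATCH_TYPES.add(VlmCausalLMBatch)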

View File

@@ -22,7 +22,6 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
-from transformers import Llama4TextConfig
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -106,7 +105,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class Llama4TextExperts(nn.Module):
-    def __init__(self, prefix, config: Llama4TextConfig, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
@@ -263,7 +262,7 @@ class Llama4TextMoe(nn.Module):
 class Llama4TextRotaryEmbedding(nn.Module):
-    def __init__(self, config: Llama4TextConfig, device=None):
+    def __init__(self, config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         self.rope_type = "llama3" if config.rope_scaling is not None else "default"

View File

@@ -23,25 +23,7 @@ from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
+from text_generation_server.models import VLM_BATCH_TYPES
-try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-    from text_generation_server.models.flash_vlm_causal_lm import (
-        FlashVlmCausalLMBatch,
-    )
-    VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
-        FlashVlmCausalLMBatch,
-        FlashMllamaCausalLMBatch,
-    }
-except (ImportError, NotImplementedError):
-    # These imports can fail on CPU/Non flash.
-    VLM_BATCH_TYPES = set()
 from text_generation_server.utils.version import (
     is_driver_compatible,
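Because models/__init__.py now exports VLM_BATCH_TYPES unconditionally, server.py can drop its own try/except registry and rely on the single import shown above on both flash and non-flash builds. A hedged sketch of how such a set is typically consumed; the helper name is illustrative and not part of this commit:

from text_generation_server.models import VLM_BATCH_TYPES

def is_vlm_batch(batch_cls) -> bool:
    # Batch classes registered in VLM_BATCH_TYPES typically need the model's
    # processor (not just the tokenizer) when requests are turned into batches.
    return batch_cls in VLM_BATCH_TYPES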

View File

@@ -7,5 +7,10 @@ if [[ "$*" == *"--sharded true"* ]]; then
     echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
     export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
 fi
+# Check if ATTENTION environment variable is set to paged
+if [[ "$ATTENTION" == "paged" ]]; then
+    echo 'ATTENTION=paged detected, installing transformers==4.51.3'
+    pip install transformers==4.51.3
+fi
 text-generation-launcher $@