diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
index 98b5d6a7..ddae0a96 100644
--- a/backends/gaudi/server/text_generation_server/models/__init__.py
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -139,10 +139,21 @@ except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
+    VLM_BATCH_TYPES = set()

 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
+
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
+


 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -848,6 +859,11 @@ def get_model(
         from text_generation_server.models.custom_modeling.llava_next import (
             LlavaNextForConditionalGeneration,
         )
+        from text_generation_server.models.vlm_causal_lm import (
+            VlmCausalLMBatch,
+        )
+
+        VLM_BATCH_TYPES.add(VlmCausalLMBatch)

         from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index ca354934..e8485df6 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -22,7 +22,6 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F

-from transformers import Llama4TextConfig
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -106,7 +105,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


 class Llama4TextExperts(nn.Module):
-    def __init__(self, prefix, config: Llama4TextConfig, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
@@ -263,7 +262,7 @@ class Llama4TextMoe(nn.Module):


 class Llama4TextRotaryEmbedding(nn.Module):
-    def __init__(self, config: Llama4TextConfig, device=None):
+    def __init__(self, config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         self.rope_type = "llama3" if config.rope_scaling is not None else "default"
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
index 6d75b46c..f9250115 100644
--- a/backends/gaudi/server/text_generation_server/server.py
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -23,25 +23,7 @@ from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
-
-try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-    from text_generation_server.models.flash_vlm_causal_lm import (
-        FlashVlmCausalLMBatch,
-    )
-
-    VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
-        FlashVlmCausalLMBatch,
-        FlashMllamaCausalLMBatch,
-    }
-except (ImportError, NotImplementedError):
-    # These imports can fail on CPU/Non flash.
-    VLM_BATCH_TYPES = set()
+from text_generation_server.models import VLM_BATCH_TYPES

 from text_generation_server.utils.version import (
     is_driver_compatible,
diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh
index a5c3f5e1..377b56d8 100644
--- a/backends/gaudi/tgi-entrypoint.sh
+++ b/backends/gaudi/tgi-entrypoint.sh
@@ -7,5 +7,10 @@ if [[ "$*" == *"--sharded true"* ]]; then
     echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
     export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
 fi
+# Check if ATTENTION environment variable is set to paged
+if [[ "$ATTENTION" == "paged" ]]; then
+    echo 'ATTENTION=paged detected, installing transformers==4.51.3'
+    pip install transformers==4.51.3
+fi

 text-generation-launcher $@
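
For context on how the consolidated VLM_BATCH_TYPES set is consumed downstream, the sketch below shows the kind of dispatch the server performs when building a batch from a protobuf request. It is a minimal illustration and not part of the patch: the helper name build_batch and its pb_batch argument are made up for this example, and the from_pb_processor / from_pb signatures are assumptions based on the existing TGI batch API.

# Illustrative sketch only (not part of the patch): how the consolidated
# VLM_BATCH_TYPES set is typically used to pick a batch constructor.
# build_batch and pb_batch are hypothetical names; the from_pb_processor /
# from_pb signatures are assumptions based on the existing TGI batch API.
from text_generation_server.models import VLM_BATCH_TYPES


def build_batch(model, pb_batch):
    """Dispatch to the right batch constructor for VLM vs. text-only models."""
    if model.batch_type in VLM_BATCH_TYPES:
        # Vision-language batches also need the processor and the model config.
        return model.batch_type.from_pb_processor(
            pb_batch,
            model.tokenizer,
            model.processor,
            model.model.config,
            model.dtype,
            model.device,
        )
    # Text-only batches keep the plain tokenizer path.
    return model.batch_type.from_pb(
        pb_batch, model.tokenizer, model.dtype, model.device
    )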