Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)
commit 50ecfc625a (parent 7533b993d5)

    refine

    Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -139,10 +139,21 @@ except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
+    VLM_BATCH_TYPES = set()
 
 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
+
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
+
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
 
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
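For context, a minimal, self-contained sketch of the guarded-registry pattern this hunk moves into the models package: VLM_BATCH_TYPES falls back to an empty set when the optional flash-attention imports fail and is populated only when FLASH_ATTENTION is available. The flash_backend module below is a placeholder, not a TGI module.

# Standalone sketch of the pattern; only the VLM_BATCH_TYPES / FLASH_ATTENTION
# names mirror the diff, everything else is illustrative.
FLASH_ATTENTION = True
VLM_BATCH_TYPES = set()

try:
    # Hypothetical optional import standing in for the flash/VLM batch classes.
    from flash_backend.batches import FlashVlmCausalLMBatch
except ImportError as e:
    print(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    VLM_BATCH_TYPES = {FlashVlmCausalLMBatch}

print(VLM_BATCH_TYPES)  # stays empty when the optional backend is unavailable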
@@ -848,6 +859,11 @@ def get_model(
         from text_generation_server.models.custom_modeling.llava_next import (
             LlavaNextForConditionalGeneration,
         )
+        from text_generation_server.models.vlm_causal_lm import (
+            VlmCausalLMBatch,
+        )
+
+        VLM_BATCH_TYPES.add(VlmCausalLMBatch)
 
     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
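The VLM_BATCH_TYPES.add(VlmCausalLMBatch) call above registers the non-flash batch class only when the corresponding model is actually built. A hedged, runnable sketch of that register-at-construction idea (names other than VLM_BATCH_TYPES and VlmCausalLMBatch are placeholders, and the real get_model does far more):

# Illustrative only: register a batch class when its model is constructed,
# rather than unconditionally at import time.
VLM_BATCH_TYPES = set()

class VlmCausalLMBatch:
    """Placeholder standing in for the real batch class."""

def get_model(model_type: str) -> str:
    if model_type == "llava_next":
        VLM_BATCH_TYPES.add(VlmCausalLMBatch)
        return "llava-next model (placeholder)"
    return "default model (placeholder)"

get_model("llava_next")
print(VlmCausalLMBatch in VLM_BATCH_TYPES)  # True once the model was built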
@@ -22,7 +22,6 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
 
-from transformers import Llama4TextConfig
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -106,7 +105,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class Llama4TextExperts(nn.Module):
-    def __init__(self, prefix, config: Llama4TextConfig, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
@@ -263,7 +262,7 @@ class Llama4TextMoe(nn.Module):
 
 
 class Llama4TextRotaryEmbedding(nn.Module):
-    def __init__(self, config: Llama4TextConfig, device=None):
+    def __init__(self, config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         self.rope_type = "llama3" if config.rope_scaling is not None else "default"
@@ -23,25 +23,7 @@ from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
-try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-    from text_generation_server.models.flash_vlm_causal_lm import (
-        FlashVlmCausalLMBatch,
-    )
-
-    VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
-        FlashVlmCausalLMBatch,
-        FlashMllamaCausalLMBatch,
-    }
-except (ImportError, NotImplementedError):
-    # These imports can fail on CPU/Non flash.
-    VLM_BATCH_TYPES = set()
+from text_generation_server.models import VLM_BATCH_TYPES
 
 from text_generation_server.utils.version import (
     is_driver_compatible,
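With this hunk, server.py stops rebuilding the registry in a local try/except and simply imports the one maintained by text_generation_server.models. A hedged sketch of how such a shared registry is typically consumed (the dispatch helper below is illustrative, not copied from server.py):

# Membership in the shared set decides whether a batch needs image handling.
VLM_BATCH_TYPES = set()  # in TGI this now comes from text_generation_server.models

class TextBatch: ...
class VisionBatch: ...

VLM_BATCH_TYPES.add(VisionBatch)

def needs_image_inputs(batch_cls) -> bool:
    # The server stays agnostic of which concrete VLM batch classes exist
    # on this hardware/build; it only checks the registry.
    return batch_cls in VLM_BATCH_TYPES

print(needs_image_inputs(VisionBatch))  # True
print(needs_image_inputs(TextBatch))    # False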
@@ -7,5 +7,10 @@ if [[ "$*" == *"--sharded true"* ]]; then
     echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
     export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
 fi
+# Check if ATTENTION environment variable is set to paged
+if [[ "$ATTENTION" == "paged" ]]; then
+    echo 'ATTENTION=paged detected, installing transformers==4.51.3'
+    pip install transformers==4.51.3
+fi
 
 text-generation-launcher $@