Repository: https://github.com/huggingface/text-generation-inference.git
commit 50ecfc625a
parent 7533b993d5

    refine

    Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -139,10 +139,21 @@ except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
+    VLM_BATCH_TYPES = set()
 
 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
+
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
+
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
 
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -848,6 +859,11 @@ def get_model(
         from text_generation_server.models.custom_modeling.llava_next import (
             LlavaNextForConditionalGeneration,
         )
+        from text_generation_server.models.vlm_causal_lm import (
+            VlmCausalLMBatch,
+        )
+
+        VLM_BATCH_TYPES.add(VlmCausalLMBatch)
 
         from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
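Taken together, the two hunks above turn VLM_BATCH_TYPES into a single module-level registry in text_generation_server.models: it starts as an empty set when the flash-attention imports fail, is filled with the flash VLM batch classes when those imports succeed, and get_model() adds VlmCausalLMBatch only on the code path that needs it. The standalone sketch below mirrors that pattern with placeholder classes; the class bodies and the final membership check are illustrative, not part of this commit.

# Standalone sketch of the registry pattern above; the classes are placeholders,
# not the real TGI batch types.
VLM_BATCH_TYPES = set()

try:
    # In TGI these imports can fail on CPU / non-flash builds; the placeholders
    # stand in for the "imports succeeded" case.
    class FlashVlmCausalLMBatch: ...
    class FlashMllamaCausalLMBatch: ...

    VLM_BATCH_TYPES = {FlashVlmCausalLMBatch, FlashMllamaCausalLMBatch}
except ImportError:
    pass  # The registry simply stays empty when the optional backend is missing.


class VlmCausalLMBatch: ...


# Later, a specific code path (get_model() in the hunk above) extends the set.
VLM_BATCH_TYPES.add(VlmCausalLMBatch)

# Consumers (such as the module patched in a later hunk, which now imports
# VLM_BATCH_TYPES instead of rebuilding it) only need a membership test.
print(VlmCausalLMBatch in VLM_BATCH_TYPES)  # True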
@@ -22,7 +22,6 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
 
-from transformers import Llama4TextConfig
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -106,7 +105,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class Llama4TextExperts(nn.Module):
-    def __init__(self, prefix, config: Llama4TextConfig, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
@@ -263,7 +262,7 @@ class Llama4TextMoe(nn.Module):
 
 
 class Llama4TextRotaryEmbedding(nn.Module):
-    def __init__(self, config: Llama4TextConfig, device=None):
+    def __init__(self, config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         self.rope_type = "llama3" if config.rope_scaling is not None else "default"
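The three Llama4 hunks above remove the module-level `from transformers import Llama4TextConfig` together with the matching constructor annotations, presumably so the modeling file imports cleanly on transformers builds that do not export the Llama4 classes; `config` is then just any object exposing the attributes the code reads. Below is a minimal sketch of that idea with illustrative names, including a TYPE_CHECKING guard shown only as a possible way to keep the hint for tooling (the commit itself simply drops the annotation).

# Illustrative sketch, not the real modeling file: no runtime dependency on
# transformers exporting Llama4TextConfig.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never imported at runtime.
    from transformers import Llama4TextConfig


class TextExpertsSketch:
    def __init__(self, prefix: str, config: "Llama4TextConfig", weights=None):
        # Only the attributes actually accessed matter; any config-like object works.
        self.num_experts = config.num_local_experts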
@@ -23,25 +23,7 @@ from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
-
-try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-    from text_generation_server.models.flash_vlm_causal_lm import (
-        FlashVlmCausalLMBatch,
-    )
-
-    VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
-        FlashVlmCausalLMBatch,
-        FlashMllamaCausalLMBatch,
-    }
-except (ImportError, NotImplementedError):
-    # These imports can fail on CPU/Non flash.
-    VLM_BATCH_TYPES = set()
+from text_generation_server.models import VLM_BATCH_TYPES
 
 from text_generation_server.utils.version import (
     is_driver_compatible,
@@ -7,5 +7,10 @@ if [[ "$*" == *"--sharded true"* ]]; then
     echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
     export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
 fi
+# Check if ATTENTION environment variable is set to paged
+if [[ "$ATTENTION" == "paged" ]]; then
+    echo 'ATTENTION=paged detected, installing transformers==4.51.3'
+    pip install transformers==4.51.3
+fi
 
 text-generation-launcher $@
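For reference, a hypothetical invocation of the patched entrypoint script (the script name and model id below are placeholders; only the ATTENTION gating comes from this commit): when ATTENTION=paged is set, the wrapper pins transformers==4.51.3 before handing off to text-generation-launcher, otherwise the image's preinstalled transformers is left untouched.

# Hypothetical usage; script path and model id are placeholders.
ATTENTION=paged ./entrypoint.sh --model-id <model-id> --sharded true
# Prints "ATTENTION=paged detected, installing transformers==4.51.3",
# runs "pip install transformers==4.51.3", then starts text-generation-launcher
# with the original arguments (and PT_HPU_ENABLE_LAZY_COLLECTIVES=1 because of
# "--sharded true").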