Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 11:24:53 +00:00

Commit 4e95db304f (parent 3aa882337e)

    Remove unnecessary modifications

    Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-#ENTRYPOINT ["/tgi-entrypoint.sh"]
-#CMD ["--json-output"]
+ENTRYPOINT ["/tgi-entrypoint.sh"]
+CMD ["--json-output"]
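Since the restored ENTRYPOINT and CMD are both in exec form, Docker appends CMD to ENTRYPOINT, so the container runs /tgi-entrypoint.sh --json-output by default and arguments given to docker run replace only the CMD part. A purely illustrative sketch of that composition (the helper function is hypothetical, not part of the repo):

# Hypothetical helper, for illustration only: how Docker combines an exec-form
# ENTRYPOINT with CMD. CMD supplies default arguments appended to ENTRYPOINT,
# and `docker run IMAGE <args>` replaces CMD, not ENTRYPOINT.
from typing import List, Optional


def effective_argv(entrypoint: List[str], cmd: List[str], run_args: Optional[List[str]] = None) -> List[str]:
    # Arguments passed on `docker run` override CMD; ENTRYPOINT always stays.
    return entrypoint + (run_args if run_args else cmd)


print(effective_argv(["/tgi-entrypoint.sh"], ["--json-output"]))
# ['/tgi-entrypoint.sh', '--json-output']
print(effective_argv(["/tgi-entrypoint.sh"], ["--json-output"], ["--port", "8080"]))
# ['/tgi-entrypoint.sh', '--port', '8080']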
@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
 
 run-local-dev-container:
 	docker run -it \
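The removed recipe also forwarded the host's http_proxy/https_proxy/no_proxy values as --build-arg entries. A hedged sketch of how such a command line could be assembled from the environment, for illustration only (this script is not part of the Makefile and only prints the command):

# Illustrative only: assemble the docker build command, forwarding proxy
# settings as --build-arg entries the way the removed recipe did.
import os
import shlex

root_dir = "."  # stand-in for ${root_dir}

cmd = [
    "docker", "build", "-t", "tgi-gaudi",
    "-f", f"{root_dir}/Dockerfile_gaudi", root_dir,
    "--build-arg", f"HABANA_VERSION={os.environ.get('HABANA_VERSION', '')}",
    "--build-arg", f"PYTORCH_VERSION={os.environ.get('PYTORCH_VERSION', '2.6.0')}",
]

# The line removed from the Makefile additionally forwarded these variables:
for var in ("http_proxy", "https_proxy", "no_proxy"):
    value = os.environ.get(var)
    if value:
        cmd += ["--build-arg", f"{var}={value}"]

print(shlex.join(cmd))  # print the command for inspection instead of running it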
@@ -57,7 +57,7 @@ def serve(
         ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    #logger.remove()
+    logger.remove()
    logger.add(
         sys.stdout,
         format="{message}",
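The uncommented logger.remove() restores loguru's standard reconfiguration idiom: drop the default stderr handler, then install a stdout handler with a bare message format. A minimal standalone sketch of that idiom (the real CLI also passes options such as the configured log level and JSON serialization, omitted here):

# Minimal standalone sketch of the restored loguru setup.
import sys

from loguru import logger

logger.remove()          # drop the default handler that writes to stderr
logger.add(
    sys.stdout,          # route records to stdout instead
    format="{message}",  # print only the message, no timestamp/level prefix
    level="INFO",
)

logger.info("server starting")  # -> "server starting"
logger.debug("not shown")       # filtered out at INFO level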
@@ -193,7 +193,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    #logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -37,6 +37,7 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
 
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
             n_experts=n_experts,
@@ -51,6 +52,7 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
 
         self.hpu_fused_moe = DynamicFusedMOE(n_experts)
         for i in range(n_experts):
             self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
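For context, the surrounding code loads stacked per-expert gate/up projections and registers them with the Habana DynamicFusedMOE op one expert at a time. Since that op only exists on HPU, the sketch below mimics the same registration pattern with a toy stand-in class; every name in it is illustrative rather than the real TGI/Habana API:

# CPU-only toy sketch of the registration pattern above. ToyFusedMOE stands in
# for the Habana DynamicFusedMOE op (HPU-only); tensor shapes are arbitrary.
import torch
from torch import nn


class ToyFusedMOE(nn.Module):
    def __init__(self, n_experts: int):
        super().__init__()
        self.w13_list = [None] * n_experts  # one gate/up weight slot per expert

    def set_weight(self, i: int, w: torch.Tensor) -> None:
        self.w13_list[i] = w


n_experts, hidden, inter = 4, 8, 16
# stand-in for _load_expert_multi_weights_col(): stacked [n_experts, 2*inter, hidden]
gate_up_proj = torch.randn(n_experts, 2 * inter, hidden)

fused_moe = ToyFusedMOE(n_experts)
for i in range(n_experts):
    # mirrors hpu_fused_moe.MoeOp.w13_list[i].set_weight(gate_up_proj[i])
    fused_moe.set_weight(i, gate_up_proj[i])

print(len(fused_moe.w13_list), fused_moe.w13_list[0].shape)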
@@ -16,9 +16,6 @@ import enum
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-#from text_generation_server.models.causal_lm import CausalLM
-#from text_generation_server.models.bloom import BLOOM
-#from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
     PhiMoEConfig,
 )
@@ -32,7 +29,6 @@ from text_generation_server.utils.adapter import (
 from text_generation_server.adapters.lora import LoraWeights
 
 from text_generation_server.utils.log import log_master
-#from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 __all__ = [
     "Model",
@@ -47,7 +43,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = False
 if ATTENTION == "paged":
     FLASH_ATTENTION = True
-print(f"Flash Attention enabled models: {FLASH_ATTENTION}")
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
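The kept lines implement a simple capability gate: a module-level FLASH_ATTENTION flag derived from the ATTENTION setting, which model selection later checks. A standalone toy sketch of that pattern, assuming (as upstream TGI does elsewhere) that ATTENTION originates from an environment variable:

# Standalone toy version of the capability gate kept above. Reading os.environ
# directly here is an assumption made for the sake of a self-contained example.
import os

ATTENTION = os.environ.get("ATTENTION", "paged")

FLASH_ATTENTION = False
if ATTENTION == "paged":
    FLASH_ATTENTION = True


def pick_model_class(model_type: str) -> str:
    # Toy stand-in for get_model(): only route to a flash implementation when
    # the flag is set; otherwise fall back to a non-flash class name.
    if FLASH_ATTENTION:
        return f"Flash{model_type.capitalize()}ForCausalLM"
    return f"{model_type.capitalize()}ForCausalLM"


print(pick_model_class("llama"))  # FlashLlamaForCausalLM when ATTENTION=paged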
@@ -459,9 +455,7 @@ def get_model(
 
     kv_cache_dtype = dtype
 
-    print(f"Model type: {model_type}")
     if FLASH_ATTENTION:
-        print(f"Flash Attention enabled models: {model_type}")
         if model_type == DEEPSEEK_V2:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
 )
-from text_generation_server.utils.log import log_master
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -47,7 +46,6 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from loguru import logger
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -61,11 +59,6 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
-def torch_save(tensor, name):
-    # Only save on the main process (rank 0) when using distributed training
-    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-        torch.save(tensor, name)
-
 
 def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
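The deleted torch_save helper dumped intermediate tensors to disk, guarding on rank 0 so tensor-parallel workers would not write the same file concurrently. For reference, a standalone sketch reproducing that debug helper (it is exactly what the commit removes, not part of the restored file):

# Standalone reproduction of the deleted debug helper: save a tensor only on
# the main process so tensor-parallel ranks do not clobber each other's files.
import torch
import torch.distributed as dist


def torch_save(tensor: torch.Tensor, name: str) -> None:
    # Single-process runs have no process group; treat them as rank 0.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(tensor, name)


if __name__ == "__main__":
    torch_save(torch.randn(2, 3), "debug_tensor.pt")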
@@ -382,7 +375,7 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-        self.index = index
         with no_fp8(weights):
             self.self_attn = FlashLlamaAttention(
                 index=index,
@@ -443,7 +436,6 @@ class FlashLlamaLayer(nn.Module):
         seqlen,
         adapter_data,
         cross_attention_states,
-        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -460,10 +452,6 @@ class FlashLlamaLayer(nn.Module):
             adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )
 
-        if run_index != -1:
-            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
-
         if self.residual_multiplier is not None:
             attn_output *= self.residual_multiplier
 
@@ -472,10 +460,6 @@ class FlashLlamaLayer(nn.Module):
         )
 
         mlp_output = self.mlp(normed_attn_res_output, adapter_data)
-        if run_index != -1:
-            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
-
-
         if self.residual_multiplier is not None:
             mlp_output *= self.residual_multiplier
 
@@ -485,7 +469,6 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.run_index = -1
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -582,12 +565,11 @@ class FlashLlamaModel(torch.nn.Module):
                 seqlen,
                 adapter_data,
                 cross_attention_states,
-                self.run_index,
                 hpu_attention_meta=hpu_attention_meta,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
-        self.run_index += 1
         return hidden_states
 
 
@@ -650,14 +632,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
-
-        log_master(
-            logger.debug,
-            f"input_ids: {input_ids}, input_ids.shape={input_ids.shape}, input_ids={input_ids[:-20]}"
-        )
         inputs_embeds = self.embed_tokens(input_ids)
-        print(f"111111111 inputs_embeds: {inputs_embeds}")
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
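The unchanged context shows the usual flow: token ids are embedded first and the decoder stack consumes the embeddings, not the ids. A toy shape check of that step, with arbitrary sizes:

# Toy shape check of the unchanged embedding step. Sizes are arbitrary.
import torch
from torch import nn

vocab_size, hidden_size, num_tokens = 32000, 64, 5

embed_tokens = nn.Embedding(vocab_size, hidden_size)
input_ids = torch.randint(0, vocab_size, (num_tokens,))
position_ids = torch.arange(num_tokens)

inputs_embeds = embed_tokens(input_ids)  # [num_tokens, hidden_size]
print(input_ids.shape, inputs_embeds.shape, position_ids.shape)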
@@ -1520,10 +1520,6 @@ class FlashCausalLM(Model):
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
-            print(f"max_input_tokens: {max_input_tokens}")
-            print(f"max_total_tokens: {max_total_tokens}")
-            print(f"num_blocks: {num_blocks}")
-            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
         self.warmup_hpu_graph(batch)
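The unchanged return line converts the warmed-up KV-cache block count into a token budget: num_blocks * BLOCK_SIZE tokens can be cached in total. A worked example with hypothetical figures:

# Worked example of the unchanged return value. All figures are hypothetical.
BLOCK_SIZE = 128        # tokens per KV-cache block (illustrative)
num_blocks = 1024       # blocks left after warmup (illustrative)
max_total_tokens = 8192

max_supported_total_tokens = int(num_blocks * BLOCK_SIZE)
print(max_supported_total_tokens)                      # 131072 cacheable tokens
print(max_supported_total_tokens // max_total_tokens)  # ~16 max-length requests at once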
@@ -1796,7 +1792,7 @@ class FlashCausalLM(Model):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = batch.prefilling
-        print(f"11111111111111111111input_ids: {input_ids.shape}")
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -24,12 +24,12 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
-#try:
+try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
     from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-    # from text_generation_server.models.vlm_causal_lm import (
-    #     VlmCausalLMBatch,
-    # )
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
     from text_generation_server.models.flash_vlm_causal_lm import (
         FlashVlmCausalLMBatch,
     )
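The restored try block is the standard optional-import pattern: collect the VLM batch classes when the flash/VLM stack imports succeed, otherwise fall back to an empty set (see the except branch in the next hunk). A standalone sketch of the pattern; the exact set membership is illustrative, and without text-generation-inference installed the script simply takes the fallback branch:

# Standalone sketch of the optional-import pattern restored above.
try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLMBatch
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch

    # Membership shown here is illustrative; the diff confirms at least the
    # Flash VLM and Mllama batch classes.
    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        FlashVlmCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU / non-flash builds.
    VLM_BATCH_TYPES = set()

print(VLM_BATCH_TYPES or "no VLM batch types available")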
@@ -39,10 +39,10 @@ VLM_BATCH_TYPES = {
         FlashVlmCausalLMBatch,
         FlashMllamaCausalLMBatch,
     }
-#except (ImportError, NotImplementedError):
+except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.
-    # print(f"importError: {ImportError}")
-    # VLM_BATCH_TYPES = set()
+    VLM_BATCH_TYPES = set()
 from text_generation_server.utils.version import (
     is_driver_compatible,
     MIN_TGI_GAUDI_SYNAPSE_VERSION,
@@ -110,7 +110,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         if ATTENTION == "paged":
             set_max_prefill_tokens(request.max_prefill_tokens)
-            print(f"VLM_BATCH_TYPES: {VLM_BATCH_TYPES}")
             if (
                 self.model.batch_type in VLM_BATCH_TYPES
             ): # Hack, i would rather use kwargs in the `from_pb` call
@@ -1,6 +1,5 @@
 import torch
 from loguru import logger
-from text_generation_server.utils.log import log_master
 
 
 def get_hpu_free_memory(device, memory_fraction):
@@ -8,7 +7,7 @@ def get_hpu_free_memory(device, memory_fraction):
 
     device_id = device.index
     mem_stats = memory_stats(device_id)
-    log_master(logger.debug, f"mem_stats: {mem_stats}")
+    logger.info(f"mem_stats: {mem_stats}")
     total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
     free_memory = max(
         0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
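The unchanged lines compute the usable HPU memory as free = max(0, (Limit - MaxInUse) - (1 - memory_fraction) * Limit): reserve a fraction of the device limit, then report what remains of the currently free memory. A worked example with made-up numbers:

# Worked example of the unchanged computation. Numbers are made up.
def hpu_free_memory(limit: int, max_in_use: int, memory_fraction: float) -> int:
    total_free = limit - max_in_use
    return max(0, int(total_free - (1 - memory_fraction) * limit))


GiB = 1024**3
# a 96 GiB limit, 10 GiB already in use, 90% of the device usable:
print(hpu_free_memory(96 * GiB, 10 * GiB, 0.9) / GiB)  # ≈ 76.4 GiB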