Remove unnecessary modifications

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-11 18:17:15 +00:00
parent 3aa882337e
commit 4e95db304f
9 changed files with 29 additions and 64 deletions

View File

@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
#ENTRYPOINT ["/tgi-entrypoint.sh"]
#CMD ["--json-output"]
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]

View File

@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
image:
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
run-local-dev-container:
docker run -it \

View File

@ -57,7 +57,7 @@ def serve(
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
#logger.remove()
logger.remove()
logger.add(
sys.stdout,
format="{message}",
@ -193,7 +193,7 @@ def download_weights(
merge_lora: bool = False,
):
# Remove default handler
#logger.remove()
logger.remove()
logger.add(
sys.stdout,
format="{message}",
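
Note on the two cli.py hunks above: they restore the active logger.remove() call in place of the commented-out one. loguru ships with a default stderr handler, so removing it before logger.add(...) keeps each message from being emitted twice. A minimal standalone sketch of the pattern (the level value is an assumption; the diff only shows the format argument):

import sys
from loguru import logger

# Drop loguru's default stderr handler so each message is logged once,
# then attach a single stdout sink with the desired format.
logger.remove()
logger.add(
    sys.stdout,
    format="{message}",
    level="INFO",  # assumed level, not taken from this diff
)

logger.info("this now goes to the stdout sink only")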

View File

@ -37,6 +37,7 @@ class UnquantizedSparseMoELayer(nn.Module):
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
@ -51,6 +52,7 @@ class UnquantizedSparseMoELayer(nn.Module):
name=down_proj_name,
weights=weights,
)
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
for i in range(n_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
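
The hunks above touch the section where the DynamicFusedMOE op is built and each expert's gate_up_proj slice is registered into it. As a point of reference only, here is a minimal unfused sketch of what such a sparse-MoE op computes per token, in plain PyTorch with hypothetical weight layouts ([experts, hidden, 2*intermediate] and [experts, intermediate, hidden]); the HPU op performs the same routing and expert mixing in a single fused kernel:

import torch
import torch.nn.functional as F

def moe_forward(x, router_logits, gate_up_proj, down_proj, top_k=2):
    # x: [tokens, hidden]; router_logits: [tokens, experts]
    weights, experts = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize top-k scores
    out = torch.zeros_like(x)
    for e in range(gate_up_proj.shape[0]):
        token_idx, slot = torch.where(experts == e)  # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        gate_up = x[token_idx] @ gate_up_proj[e]          # [n, 2*intermediate]
        gate, up = gate_up.chunk(2, dim=-1)
        expert_out = (F.silu(gate) * up) @ down_proj[e]   # [n, hidden]
        out[token_idx] += weights[token_idx, slot].unsqueeze(-1) * expert_out
    return out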

View File

@ -16,9 +16,6 @@ import enum
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
#from text_generation_server.models.causal_lm import CausalLM
#from text_generation_server.models.bloom import BLOOM
#from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
@ -32,7 +29,6 @@ from text_generation_server.utils.adapter import (
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
#from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
__all__ = [
"Model",
@ -47,7 +43,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True
print(f"Flash Attention enabled models: {FLASH_ATTENTION}")
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
@ -459,9 +455,7 @@ def get_model(
kv_cache_dtype = dtype
print(f"Model type: {model_type}")
if FLASH_ATTENTION:
print(f"Flash Attention enabled models: {model_type}")
if model_type == DEEPSEEK_V2:
head_size = max(
config_dict.get("qk_nope_dim", 128)

View File

@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.utils.log import log_master
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.attention import (
paged_attention,
@ -47,7 +46,6 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from loguru import logger
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -61,11 +59,6 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
def torch_save(tensor, name):
# Only save on the main process (rank 0) when using distributed training
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
torch.save(tensor, name)
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
@ -382,7 +375,7 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights):
super().__init__()
self.index = index
with no_fp8(weights):
self.self_attn = FlashLlamaAttention(
index=index,
@ -443,7 +436,6 @@ class FlashLlamaLayer(nn.Module):
seqlen,
adapter_data,
cross_attention_states,
run_index,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -460,10 +452,6 @@ class FlashLlamaLayer(nn.Module):
adapter_data,
hpu_attention_meta=hpu_attention_meta,
)
if run_index != -1:
torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
@ -472,10 +460,6 @@ class FlashLlamaLayer(nn.Module):
)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
if run_index != -1:
torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
@ -485,7 +469,6 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.run_index = -1
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
@ -582,12 +565,11 @@ class FlashLlamaModel(torch.nn.Module):
seqlen,
adapter_data,
cross_attention_states,
self.run_index,
hpu_attention_meta=hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
self.run_index += 1
return hidden_states
@ -650,14 +632,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
log_master(
logger.debug,
f"input_ids: {input_ids}, input_ids.shape={input_ids.shape}, input_ids={input_ids[:-20]}"
)
inputs_embeds = self.embed_tokens(input_ids)
print(f"111111111 inputs_embeds: {inputs_embeds}")
hidden_states = self.model(
inputs_embeds,
position_ids,
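
The hunks above strip the ad-hoc debugging that had been threaded through the Llama forward pass: the rank-0 torch_save helper, the run_index plumbing through the layer signatures, and the raw print/log_master calls on the inputs. If that kind of dump is ever needed again, it is cleaner to keep it as a standalone, opt-in helper rather than wiring it into the model code; a small sketch of that idea (the environment variable name is hypothetical):

import os
import torch
import torch.distributed as dist

def dump_debug_tensor(tensor: torch.Tensor, name: str) -> None:
    # Opt-in tensor dump: only active when TGI_DEBUG_DUMP=1 (hypothetical
    # variable), and only on rank 0 (or in single-process runs), mirroring
    # the guard used by the removed torch_save helper.
    if os.getenv("TGI_DEBUG_DUMP", "0") != "1":
        return
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(tensor.detach().cpu(), name)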

View File

@ -1520,10 +1520,6 @@ class FlashCausalLM(Model):
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
logger.info("skip warmup hpu graph, not recommmended")
del _batch, batch
print(f"max_input_tokens: {max_input_tokens}")
print(f"max_total_tokens: {max_total_tokens}")
print(f"num_blocks: {num_blocks}")
print(f"BLOCK_SIZE: {BLOCK_SIZE}")
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
self.warmup_hpu_graph(batch)
@ -1796,7 +1792,7 @@ class FlashCausalLM(Model):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = batch.prefilling
print(f"11111111111111111111input_ids: {input_ids.shape}")
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,

View File

@ -24,12 +24,12 @@ from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
#try:
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
# from text_generation_server.models.vlm_causal_lm import (
# VlmCausalLMBatch,
# )
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
)
@ -39,10 +39,10 @@ VLM_BATCH_TYPES = {
FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
}
#except (ImportError, NotImplementedError):
except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash.
# print(f"importError: {ImportError}")
# VLM_BATCH_TYPES = set()
VLM_BATCH_TYPES = set()
from text_generation_server.utils.version import (
is_driver_compatible,
MIN_TGI_GAUDI_SYNAPSE_VERSION,
@ -110,7 +110,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Warmup(self, request, context):
if ATTENTION == "paged":
set_max_prefill_tokens(request.max_prefill_tokens)
print(f"VLM_BATCH_TYPES: {VLM_BATCH_TYPES}")
if (
self.model.batch_type in VLM_BATCH_TYPES
): # Hack, i would rather use kwargs in the `from_pb` call
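
The server.py hunks above re-enable the try/except guard around the VLM batch imports, so builds where those models cannot be imported fall back to an empty VLM_BATCH_TYPES and the membership check in Warmup simply never matches. A minimal sketch of the same optional-import pattern, with hypothetical module and class names:

# Optional-import guard: collect the batch classes only if their modules
# import cleanly; otherwise fall back to an empty set so later
# "batch_type in VLM_BATCH_TYPES" checks are never true.
try:
    from my_backend.vlm import VlmBatch          # hypothetical
    from my_backend.mllama import MllamaBatch    # hypothetical

    VLM_BATCH_TYPES = {VlmBatch, MllamaBatch}
except (ImportError, NotImplementedError):
    # These imports can fail on CPU / non-flash builds.
    VLM_BATCH_TYPES = set()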

View File

@ -1,6 +1,5 @@
import torch
from loguru import logger
from text_generation_server.utils.log import log_master
def get_hpu_free_memory(device, memory_fraction):
@ -8,7 +7,7 @@ def get_hpu_free_memory(device, memory_fraction):
device_id = device.index
mem_stats = memory_stats(device_id)
log_master(logger.debug, f"mem_stats: {mem_stats}")
logger.info(f"mem_stats: {mem_stats}")
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
free_memory = max(
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
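
For clarity, a worked example of the free-memory formula above, using hypothetical mem_stats values and memory_fraction = 0.8:

# Hypothetical numbers, in bytes, purely to illustrate the computation above.
mem_stats = {"Limit": 32 * 1024**3, "MaxInUse": 8 * 1024**3}
memory_fraction = 0.8

total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]  # 24 GiB currently unused
free_memory = max(
    0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
)  # 24 GiB minus the 6.4 GiB reserve outside memory_fraction, about 17.6 GiB
print(free_memory)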