From 4e95db304fbfb9a89b1707314341e86fe9a5cba9 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Sun, 11 May 2025 18:17:15 +0000
Subject: [PATCH] Remove unnecessary modifications

Signed-off-by: yuanwu
---
 Dockerfile_gaudi                                   |  4 +--
 backends/gaudi/Makefile                            |  2 +-
 .../server/text_generation_server/cli.py           |  4 +--
 .../layers/moe/unquantized.py                      |  2 ++
 .../text_generation_server/models/__init__.py      |  8 +----
 .../custom_modeling/flash_llama_modeling.py        | 29 ++-------------
 .../models/flash_causal_lm.py                      |  6 +---
 .../server/text_generation_server/server.py        | 35 +++++++++----------
 .../utils/import_utils.py                          |  3 +-
 9 files changed, 29 insertions(+), 64 deletions(-)

diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index 20c03cb3..06073fe4 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-#ENTRYPOINT ["/tgi-entrypoint.sh"]
-#CMD ["--json-output"]
+ENTRYPOINT ["/tgi-entrypoint.sh"]
+CMD ["--json-output"]
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
index 43705bb6..f760f4d6 100644
--- a/backends/gaudi/Makefile
+++ b/backends/gaudi/Makefile
@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
 
 run-local-dev-container:
 	docker run -it \
diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index b721bc3c..53837ef7 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -57,7 +57,7 @@ def serve(
     ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    #logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -193,7 +193,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    #logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
index 43bc46ce..ec158398 100644
--- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
+++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
@@ -37,6 +37,7 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
             n_experts=n_experts,
@@ -51,6 +52,7 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
+
         self.hpu_fused_moe = DynamicFusedMOE(n_experts)
         for i in range(n_experts):
             self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
index 530b95d5..d8ea0077 100644
--- a/backends/gaudi/server/text_generation_server/models/__init__.py
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -16,9 +16,6 @@ import enum
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-#from text_generation_server.models.causal_lm import CausalLM
-#from text_generation_server.models.bloom import BLOOM
-#from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
     PhiMoEConfig,
 )
@@ -32,7 +29,6 @@ from text_generation_server.utils.adapter import (
 )
 from text_generation_server.adapters.lora import LoraWeights
 from text_generation_server.utils.log import log_master
-#from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 __all__ = [
     "Model",
@@ -47,7 +43,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = False
 if ATTENTION == "paged":
     FLASH_ATTENTION = True
-print(f"Flash Attention enabled models: {FLASH_ATTENTION}")
+
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
@@ -459,9 +455,7 @@ def get_model(
 
     kv_cache_dtype = dtype
 
-    print(f"Model type: {model_type}")
     if FLASH_ATTENTION:
-        print(f"Flash Attention enabled models: {model_type}")
         if model_type == DEEPSEEK_V2:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 7a6e561f..fb1154a4 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
 )
-from text_generation_server.utils.log import log_master
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -47,7 +46,6 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from loguru import logger
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -61,11 +59,6 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
-def torch_save(tensor, name):
-    # Only save on the main process (rank 0) when using distributed training
-    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-        torch.save(tensor, name)
-
 
 def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
@@ -382,7 +375,7 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-        self.index = index
+
         with no_fp8(weights):
             self.self_attn = FlashLlamaAttention(
                 index=index,
@@ -443,7 +436,6 @@ class FlashLlamaLayer(nn.Module):
         seqlen,
         adapter_data,
         cross_attention_states,
-        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -460,10 +452,6 @@ class FlashLlamaLayer(nn.Module):
             adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )
-
-        if run_index != -1:
-            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
-
         if self.residual_multiplier is not None:
             attn_output *= self.residual_multiplier
 
@@ -472,10 +460,6 @@ class FlashLlamaLayer(nn.Module):
         )
 
         mlp_output = self.mlp(normed_attn_res_output, adapter_data)
-        if run_index != -1:
-            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
-
-
         if self.residual_multiplier is not None:
             mlp_output *= self.residual_multiplier
 
@@ -485,7 +469,6 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.run_index = -1
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -582,12 +565,11 @@ class FlashLlamaModel(torch.nn.Module):
                 seqlen,
                 adapter_data,
                 cross_attention_states,
-                self.run_index,
                 hpu_attention_meta=hpu_attention_meta,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
-        self.run_index += 1
+
         return hidden_states
 
 
@@ -650,14 +632,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
-
-        log_master(
-            logger.debug,
-            f"input_ids: {input_ids}, input_ids.shape={input_ids.shape}, input_ids={input_ids[:-20]}"
-        )
         inputs_embeds = self.embed_tokens(input_ids)
-        print(f"111111111 inputs_embeds: {inputs_embeds}")
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 1f55c27e..ecedd4aa 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1520,10 +1520,6 @@ class FlashCausalLM(Model):
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
-            print(f"max_input_tokens: {max_input_tokens}")
-            print(f"max_total_tokens: {max_total_tokens}")
-            print(f"num_blocks: {num_blocks}")
-            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
         self.warmup_hpu_graph(batch)
@@ -1796,7 +1792,7 @@ class FlashCausalLM(Model):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = batch.prefilling
-        print(f"11111111111111111111input_ids: {input_ids.shape}")
+
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
index a08662e6..6d75b46c 100644
--- a/backends/gaudi/server/text_generation_server/server.py
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -24,25 +24,25 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
-#try:
-from text_generation_server.models.pali_gemma import PaliGemmaBatch
-from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-# from text_generation_server.models.vlm_causal_lm import (
-#     VlmCausalLMBatch,
-# )
-from text_generation_server.models.flash_vlm_causal_lm import (
-    FlashVlmCausalLMBatch,
-)
+try:
+    from text_generation_server.models.pali_gemma import PaliGemmaBatch
+    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
 
-VLM_BATCH_TYPES = {
-    PaliGemmaBatch,
-    FlashVlmCausalLMBatch,
-    FlashMllamaCausalLMBatch,
-}
-#except (ImportError, NotImplementedError):
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
+except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.
-    # print(f"importError: {ImportError}")
-    # VLM_BATCH_TYPES = set()
+    VLM_BATCH_TYPES = set()
+
 from text_generation_server.utils.version import (
     is_driver_compatible,
     MIN_TGI_GAUDI_SYNAPSE_VERSION,
@@ -110,7 +110,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         if ATTENTION == "paged":
             set_max_prefill_tokens(request.max_prefill_tokens)
-            print(f"VLM_BATCH_TYPES: {VLM_BATCH_TYPES}")
         if (
             self.model.batch_type in VLM_BATCH_TYPES
         ):  # Hack, i would rather use kwargs in the `from_pb` call
diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
index 2900c25c..22560dd7 100644
--- a/backends/gaudi/server/text_generation_server/utils/import_utils.py
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -1,6 +1,5 @@
 import torch
 from loguru import logger
-from text_generation_server.utils.log import log_master
 
 
 def get_hpu_free_memory(device, memory_fraction):
@@ -8,7 +7,7 @@ def get_hpu_free_memory(device, memory_fraction):
 
     device_id = device.index
     mem_stats = memory_stats(device_id)
-    log_master(logger.debug, f"mem_stats: {mem_stats}")
+    logger.info(f"mem_stats: {mem_stats}")
     total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
     free_memory = max(
         0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])