Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)
Remove unnecessary modifications

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent 3aa882337e
commit 4e95db304f
@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-#ENTRYPOINT ["/tgi-entrypoint.sh"]
-#CMD ["--json-output"]
+ENTRYPOINT ["/tgi-entrypoint.sh"]
+CMD ["--json-output"]

@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
 
 run-local-dev-container:
 	docker run -it \

@@ -57,7 +57,7 @@ def serve(
         ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    #logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",

@@ -193,7 +193,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    #logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",

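Note on the two cli.py hunks above: loguru installs a default stderr handler at import time, so restoring logger.remove() before logger.add(...) keeps each message from being emitted twice once the plain stdout sink is attached. A minimal, self-contained sketch of that pattern (the level value here is illustrative, not the repo's exact configuration):

import sys
from loguru import logger

# Drop loguru's default stderr sink so messages are not duplicated.
logger.remove()

# Re-attach a bare stdout sink; "{message}" prints the text with no timestamp or level prefix.
logger.add(sys.stdout, format="{message}", level="INFO")

logger.info("logger configured for the serve / download-weights entrypoints")
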
@@ -37,6 +37,7 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
             n_experts=n_experts,

@@ -51,6 +52,7 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
+
         self.hpu_fused_moe = DynamicFusedMOE(n_experts)
         for i in range(n_experts):
             self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])

@@ -16,9 +16,6 @@ import enum
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-#from text_generation_server.models.causal_lm import CausalLM
-#from text_generation_server.models.bloom import BLOOM
-#from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
     PhiMoEConfig,
 )

@@ -32,7 +29,6 @@ from text_generation_server.utils.adapter import (
 from text_generation_server.adapters.lora import LoraWeights
 
 from text_generation_server.utils.log import log_master
-#from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 __all__ = [
     "Model",

@@ -47,7 +43,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = False
 if ATTENTION == "paged":
     FLASH_ATTENTION = True
-print(f"Flash Attention enabled models: {FLASH_ATTENTION}")
+
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM

@@ -459,9 +455,7 @@ def get_model(
 
     kv_cache_dtype = dtype
 
-    print(f"Model type: {model_type}")
     if FLASH_ATTENTION:
-        print(f"Flash Attention enabled models: {model_type}")
         if model_type == DEEPSEEK_V2:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)

@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
 )
-from text_generation_server.utils.log import log_master
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.attention import (
     paged_attention,

@@ -47,7 +46,6 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from loguru import logger
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,

@@ -61,11 +59,6 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
 
-def torch_save(tensor, name):
-    # Only save on the main process (rank 0) when using distributed training
-    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-        torch.save(tensor, name)
-
 def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.

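The deleted torch_save helper above was debug-only, but the guard it used is a common pattern: perform a side effect on a single process so that every rank in a distributed run does not write the same file. A generic sketch of that guard, not code from this commit:

import torch
import torch.distributed as dist

def save_debug_tensor(tensor: torch.Tensor, path: str) -> None:
    # Save only when distributed mode is off, or on rank 0 of the default process group.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(tensor.detach().cpu(), path)
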
@@ -382,7 +375,7 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-        self.index = index
+
         with no_fp8(weights):
             self.self_attn = FlashLlamaAttention(
                 index=index,

@@ -443,7 +436,6 @@ class FlashLlamaLayer(nn.Module):
         seqlen,
         adapter_data,
         cross_attention_states,
-        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -460,10 +452,6 @@ class FlashLlamaLayer(nn.Module):
             adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )
 
-        if run_index != -1:
-            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
-
-
         if self.residual_multiplier is not None:
             attn_output *= self.residual_multiplier

@@ -472,10 +460,6 @@ class FlashLlamaLayer(nn.Module):
         )
 
         mlp_output = self.mlp(normed_attn_res_output, adapter_data)
-        if run_index != -1:
-            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
-
-
         if self.residual_multiplier is not None:
             mlp_output *= self.residual_multiplier
 

@@ -485,7 +469,6 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.run_index = -1
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()

@@ -582,12 +565,11 @@ class FlashLlamaModel(torch.nn.Module):
                 seqlen,
                 adapter_data,
                 cross_attention_states,
-                self.run_index,
                 hpu_attention_meta=hpu_attention_meta,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
-        self.run_index += 1
+
         return hidden_states
 
 

@@ -650,14 +632,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
-
-        log_master(
-            logger.debug,
-            f"input_ids: {input_ids}, input_ids.shape={input_ids.shape}, input_ids={input_ids[:-20]}"
-        )
         inputs_embeds = self.embed_tokens(input_ids)
-        print(f"111111111 inputs_embeds: {inputs_embeds}")
         hidden_states = self.model(
             inputs_embeds,
             position_ids,

@@ -1520,10 +1520,6 @@ class FlashCausalLM(Model):
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
-            print(f"max_input_tokens: {max_input_tokens}")
-            print(f"max_total_tokens: {max_total_tokens}")
-            print(f"num_blocks: {num_blocks}")
-            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
         self.warmup_hpu_graph(batch)

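For context on the warmup hunk above: the surviving lines show that setting the VLLM_SKIP_WARMUP environment variable to "true" makes warmup return immediately, and that the reported KV-cache budget is num_blocks * BLOCK_SIZE tokens. A small sketch of that arithmetic with made-up values (the numbers are illustrative only, not measured on hardware):

import os

# Any value other than the string "true" (case-insensitive) keeps HPU graph warmup enabled.
skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"

num_blocks = 1024   # hypothetical count of free KV-cache blocks found during warmup
BLOCK_SIZE = 128    # hypothetical tokens per block
print(skip_warmup, num_blocks * BLOCK_SIZE)  # 131072 tokens of KV-cache capacity
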
@@ -1796,7 +1792,7 @@ class FlashCausalLM(Model):
             kwargs = {}
             if htorch.utils.internal.is_lazy():
                 kwargs["bypass_hpu_graphs"] = batch.prefilling
-            print(f"11111111111111111111input_ids: {input_ids.shape}")
+
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,

@@ -24,25 +24,25 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
-#try:
-from text_generation_server.models.pali_gemma import PaliGemmaBatch
-from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-# from text_generation_server.models.vlm_causal_lm import (
-#     VlmCausalLMBatch,
-# )
-from text_generation_server.models.flash_vlm_causal_lm import (
-    FlashVlmCausalLMBatch,
-)
-VLM_BATCH_TYPES = {
-    PaliGemmaBatch,
-    FlashVlmCausalLMBatch,
-    FlashMllamaCausalLMBatch,
-}
-#except (ImportError, NotImplementedError):
-    # These imports can fail on CPU/Non flash.
-    # print(f"importError: {ImportError}")
-    # VLM_BATCH_TYPES = set()
+try:
+    from text_generation_server.models.pali_gemma import PaliGemmaBatch
+    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
+
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
+except (ImportError, NotImplementedError):
+    # These imports can fail on CPU/Non flash.
+    VLM_BATCH_TYPES = set()
 
 from text_generation_server.utils.version import (
     is_driver_compatible,
     MIN_TGI_GAUDI_SYNAPSE_VERSION,

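The restored try/except above is the standard optional-import fallback: attempt to import accelerator-specific batch classes, and degrade to an empty set when they are unavailable so later membership checks still work. A generic, self-contained sketch of the same pattern (the module and class names here are hypothetical, not this repo's):

try:
    from my_backend.vision_batches import ImageBatch, VideoBatch  # hypothetical modules

    SUPPORTED_BATCH_TYPES = {ImageBatch, VideoBatch}
except (ImportError, NotImplementedError):
    # These imports can fail on hardware without the vision path;
    # an empty set keeps `batch_type in SUPPORTED_BATCH_TYPES` checks valid.
    SUPPORTED_BATCH_TYPES = set()
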
@@ -110,7 +110,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         if ATTENTION == "paged":
             set_max_prefill_tokens(request.max_prefill_tokens)
-            print(f"VLM_BATCH_TYPES: {VLM_BATCH_TYPES}")
             if (
                 self.model.batch_type in VLM_BATCH_TYPES
             ):  # Hack, i would rather use kwargs in the `from_pb` call

@@ -1,6 +1,5 @@
 import torch
 from loguru import logger
-from text_generation_server.utils.log import log_master
 
 
 def get_hpu_free_memory(device, memory_fraction):

@@ -8,7 +7,7 @@ def get_hpu_free_memory(device, memory_fraction):
 
     device_id = device.index
     mem_stats = memory_stats(device_id)
-    log_master(logger.debug, f"mem_stats: {mem_stats}")
+    logger.info(f"mem_stats: {mem_stats}")
     total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
     free_memory = max(
         0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])

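The last hunk keeps the free-memory formula intact: the device reports Limit (total HPU memory visible to the process) and MaxInUse (peak allocation so far), and a (1 - memory_fraction) share of Limit stays reserved. A worked example with made-up numbers, not values measured on real hardware:

limit = 96 * 1024**3        # hypothetical 96 GiB HPU memory limit
max_in_use = 20 * 1024**3   # hypothetical 20 GiB peak usage so far
memory_fraction = 0.9       # fraction of the device TGI is allowed to claim

total_free = limit - max_in_use
free = max(0, int(total_free - (1 - memory_fraction) * limit))
print(free / 1024**3)       # about 66.4 GiB left for the KV cache and graphs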