Fix crash

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-04 09:28:02 +00:00
parent 3482d7ca82
commit ccddbba752
5 changed files with 447 additions and 1341 deletions

View File

@@ -25,7 +25,6 @@ class FastLinear(torch.nn.Module):
         return cls(weight, bias)
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        print(f"input.shape={input.shape}, self.weight={self.weight.shape}")
         return F.linear(input, self.weight, self.bias)

View File

@@ -37,7 +37,6 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
             n_experts=n_experts,
@@ -52,7 +51,6 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
         self.hpu_fused_moe = DynamicFusedMOE(n_experts)
         for i in range(n_experts):
             self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
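
Side note on the registration loop above: in the HPU fused-MoE path, each expert's fused gate/up weight in w13_list is normally paired with a down-projection weight in w2_list. A minimal sketch of that pattern, with the import path and the w2_list line treated as assumptions rather than lines shown in this diff:

    # Sketch only, not part of this commit: per-expert weight registration for the
    # Habana fused-MoE op. Assumes gate_up_proj / down_proj hold one tensor per expert
    # and that DynamicFusedMOE comes from the vllm_hpu_extension ops module.
    from vllm_hpu_extension.ops import DynamicFusedMOE

    hpu_fused_moe = DynamicFusedMOE(n_experts)
    for i in range(n_experts):
        hpu_fused_moe.MoeOp.w13_list[i].set_weight(gate_up_proj[i])  # fused gate + up projection
        hpu_fused_moe.MoeOp.w2_list[i].set_weight(down_proj[i])      # down projection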

View File

@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
 )
+from text_generation_server.utils.log import log_master
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -46,6 +47,7 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from loguru import logger
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -633,7 +635,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        log_master(
+            logger.debug,
+            f"input_ids: {input_ids}, input_ids.shape={input_ids.shape}, input_ids={input_ids[:-20]}"
+        )
         inputs_embeds = self.embed_tokens(input_ids)
+        print(f"111111111 inputs_embeds: {inputs_embeds}")
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
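
Note on the debug block added above: log_master is the repo's helper for emitting a log line only from the main process, so sharded runs do not print the same message once per rank. A rough sketch of what such a helper looks like, as an assumption about text_generation_server/utils/log.py rather than a copy of it:

    # Rough sketch (assumption): forward a message to the given loguru method
    # only on the master / rank-0 shard.
    import os
    from loguru import logger

    def log_master(log, msg: str):
        if int(os.getenv("RANK", "0")) == 0:
            log(msg)

    # usage matching the call added in this hunk:
    # log_master(logger.debug, f"input_ids.shape={input_ids.shape}")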

View File

@@ -1792,7 +1792,7 @@ class FlashCausalLM(Model):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
             kwargs["bypass_hpu_graphs"] = batch.prefilling
+        print(f"11111111111111111111input_ids: {input_ids.shape}")
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
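
For context on the surrounding lines: bypass_hpu_graphs only has an effect when the model has been wrapped for HPU graph replay in lazy mode, and passing it as True (here, during prefill) skips graph capture/replay for that call. A minimal sketch of the wrapping step, with the wrap_in_hpu_graph call treated as an assumption about the Habana API rather than something shown in this diff:

    # Sketch only: wrap a module so later forward calls can opt out of HPU graph
    # replay via bypass_hpu_graphs=True (e.g. variable-shape prefill batches).
    import habana_frameworks.torch as htorch

    if htorch.utils.internal.is_lazy():
        model = htorch.hpu.wrap_in_hpu_graph(model)

    # later, as in the hunk above:
    # logits, speculative_logits = model.forward(input_ids=input_ids, bypass_hpu_graphs=batch.prefilling)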