Remove useless modification

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-06-06 06:46:06 +00:00
parent eed58b77c3
commit 4a89f59ec7
3 changed files with 2 additions and 23 deletions

View File

@@ -64,7 +64,7 @@ def serve(
         ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    # logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -203,7 +203,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    # logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",

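For reference, the two hunks above simply restore the call that had been commented out: with loguru (which this CLI appears to use, given the logger.remove()/logger.add() pair), the default stderr handler must be removed before adding a new sink, otherwise every record is emitted twice. A minimal sketch of that pattern, separate from the real cli.py:

import sys
from loguru import logger

logger.remove()          # drop loguru's default stderr handler
logger.add(
    sys.stdout,
    format="{message}",  # bare message format, as in the diff
)

logger.info("starting server")  # printed once, to stdout only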
View File

@@ -69,20 +69,6 @@ class FetchFromCache(torch.nn.Module):
         return out
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 def attention(
     *,
     query: torch.Tensor,
@@ -95,7 +81,6 @@ def attention(
     window_size_left: int = -1,
     causal: bool = True,
     softcap: Optional[float] = None,
-    num_key_value_groups: int = 1,
 ):
     fsdpa_op = ModuleFusedSDPA(FusedSDPA)
     bs = seqlen.input_lengths.shape[0]
@@ -103,9 +88,6 @@ def attention(
     _, kv_head_num, head_size = key.shape
     query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
     key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
     value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
-    key = repeat_kv(key, num_key_value_groups)
-    value = repeat_kv(value, num_key_value_groups)
     attn_output = fsdpa_op(
         query,
         key,

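The deleted repeat_kv helper broadcast the key/value heads up to the query head count for grouped-query attention; as its docstring states, it is equivalent to torch.repeat_interleave on dim=1. A small standalone check of that equivalence (shapes are arbitrary):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Copy of the removed helper: repeat each KV head n_rep times along dim 1.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

x = torch.randn(2, 4, 16, 64)  # (batch, num_kv_heads, seqlen, head_dim)
assert torch.equal(repeat_kv(x, 8), torch.repeat_interleave(x, repeats=8, dim=1))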
View File

@@ -86,7 +86,6 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -161,7 +160,6 @@ class Qwen3Attention(nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
-                num_key_value_groups=self.num_key_value_groups,
             )
         # Decode
         else:
@@ -277,7 +275,6 @@ class Qwen3Model(nn.Module):
         )
         residual = None
-
         lazy_mode = htorch.utils.internal.is_lazy()
         if lazy_mode:
             htorch.core.mark_step()
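
The mark_step lines kept in the last hunk follow the usual HPU lazy-mode pattern: operations are only queued until htorch.core.mark_step() flushes the accumulated graph to the device. A hedged sketch of that pattern in isolation, assuming a Gaudi device and habana_frameworks installed:

import torch
import habana_frameworks.torch as htorch

x = torch.randn(4, 4, device="hpu")
y = x @ x  # queued, not executed yet, while lazy mode is enabled

if htorch.utils.internal.is_lazy():
    htorch.core.mark_step()  # flush the accumulated graph to the HPU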