commit 4a89f59ec7
parent eed58b77c3

    Remove useless modification

    Signed-off-by: yuanwu <yuan.wu@intel.com>

@@ -64,7 +64,7 @@ def serve(
         ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    # logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -203,7 +203,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    # logger.remove()
+    logger.remove()
    logger.add(
         sys.stdout,
         format="{message}",
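
Both hunks above restore the same logging setup: the default handler is dropped before a plain-stdout sink is installed, rather than leaving the `logger.remove()` call commented out. A minimal, self-contained sketch of that pattern follows; it assumes `logger` is loguru's (its `remove()`/`add()` API matches the calls in the diff), and the `level` argument and final log call are illustrative, not lines from the commit.

```python
# Sketch of the restored logging setup, assuming `logger` is loguru's.
import sys

from loguru import logger

# Without remove(), loguru keeps its default stderr handler and every
# message would be emitted twice once the stdout sink below is added.
logger.remove()
logger.add(
    sys.stdout,
    format="{message}",   # emit only the message text, as in the diff
    level="INFO",         # hypothetical; the real CLI configures its own level
)

logger.info("weights downloaded")  # printed once, to stdout, as plain text
```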
@@ -69,20 +69,6 @@ class FetchFromCache(torch.nn.Module):
         return out
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 def attention(
     *,
     query: torch.Tensor,
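
The removed `repeat_kv` helper's docstring claims it is equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`. A standalone sketch that checks this claim; the tensor sizes and `n_rep=8` are assumptions for the example, not values from the diff.

```python
# Checks that the expand+reshape trick in the removed helper matches
# torch.repeat_interleave along the KV-head axis. Shapes are illustrative.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seqlen, head_dim)
    # -> (batch, num_key_value_heads * n_rep, seqlen, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


x = torch.randn(2, 4, 16, 64)            # batch=2, kv_heads=4, seqlen=16, head_dim=64
out = repeat_kv(x, n_rep=8)              # -> (2, 32, 16, 64)
ref = torch.repeat_interleave(x, repeats=8, dim=1)
assert out.shape == (2, 32, 16, 64)
assert torch.equal(out, ref)             # the two formulations agree exactly
```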
@@ -95,7 +81,6 @@ def attention(
     window_size_left: int = -1,
     causal: bool = True,
     softcap: Optional[float] = None,
-    num_key_value_groups: int = 1,
 ):
     fsdpa_op = ModuleFusedSDPA(FusedSDPA)
     bs = seqlen.input_lengths.shape[0]
@@ -103,9 +88,6 @@ def attention(
     _, kv_head_num, head_size = key.shape
     query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
     key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
-    value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
-    key = repeat_kv(key, num_key_value_groups)
-    value = repeat_kv(value, num_key_value_groups)
     attn_output = fsdpa_op(
         query,
         key,
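
For the view/transpose lines kept in this hunk, here is a shape-only sketch of how the packed (tokens, heads, head_dim) inputs become the (batch, heads, seq, head_dim) layout passed to `fsdpa_op`. It assumes every sequence in the batch has the same length; all sizes are illustrative, and in the real code the lengths come from `seqlen.input_lengths`.

```python
# Shape-only sketch of the view/transpose calls above; batch size, sequence
# length and head counts are illustrative assumptions.
import torch

bs, seq_len, head_num, kv_head_num, head_size = 2, 16, 32, 8, 64

# Packed kernel inputs: one row per token across the whole batch.
query = torch.randn(bs * seq_len, head_num, head_size)
key = torch.randn(bs * seq_len, kv_head_num, head_size)

# (tokens, heads, head_dim) -> (batch, heads, seq, head_dim)
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)

assert query.shape == (bs, head_num, seq_len, head_size)   # (2, 32, 16, 64)
assert key.shape == (bs, kv_head_num, seq_len, head_size)  # (2, 8, 16, 64)
```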
@@ -86,7 +86,6 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )
 
-
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
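
The `num_groups` line kept here records how many query heads share each KV head (grouped-query attention). A small sketch with hypothetical head counts; the real values come from the model config, the mapping is built on CPU here instead of `weights.device`, and the final `repeat_interleave` is only a plausible continuation since the hunk ends before it.

```python
# GQA bookkeeping sketch; 32 query heads and 8 KV heads are example values,
# not taken from the diff.
import torch

num_heads = 32
num_key_value_heads = 8

num_groups = num_heads // num_key_value_heads   # 4 query heads per KV head
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32)

# Plausible continuation: repeat each KV index num_groups times so every
# query head knows which KV head backs it.
per_query_head = kv_head_mapping.repeat_interleave(num_groups)

print(num_groups)       # 4
print(per_query_head)   # 4 copies of each KV-head index: 0,0,0,0,1,1,1,1, ...
```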
@@ -161,7 +160,6 @@ class Qwen3Attention(nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
-                num_key_value_groups=self.num_key_value_groups,
             )
         # Decode
         else:
@@ -277,7 +275,6 @@ class Qwen3Model(nn.Module):
         )
 
         residual = None
-
         lazy_mode = htorch.utils.internal.is_lazy()
         if lazy_mode:
             htorch.core.mark_step()
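
The final hunk keeps the lazy-mode guard around `htorch.core.mark_step()`. A minimal sketch of that pattern, assuming `htorch` is bound via `import habana_frameworks.torch as htorch` (the usual binding, though the import itself is outside this diff); the layer loop and function name are illustrative.

```python
# Lazy-mode guard sketch: mark_step() flushes the accumulated lazy graph and
# is only called when the HPU bridge actually runs in lazy mode.
import habana_frameworks.torch as htorch


def run_decoder_layers(hidden_states, layers):
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()
    for layer in layers:
        hidden_states = layer(hidden_states)
        if lazy_mode:
            # Cut the graph after each layer so it is compiled/executed in chunks.
            htorch.core.mark_step()
    return hidden_states
```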