diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index 8fa2b263..d4445a13 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -64,7 +64,7 @@ def serve(
         ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    # logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -203,7 +203,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    # logger.remove()
+    logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index a856ffb7..5aec87c2 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -69,20 +69,6 @@ class FetchFromCache(torch.nn.Module):
         return out
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 def attention(
     *,
     query: torch.Tensor,
@@ -95,7 +81,6 @@ def attention(
     window_size_left: int = -1,
     causal: bool = True,
     softcap: Optional[float] = None,
-    num_key_value_groups: int = 1,
 ):
     fsdpa_op = ModuleFusedSDPA(FusedSDPA)
     bs = seqlen.input_lengths.shape[0]
@@ -103,9 +88,6 @@ def attention(
     _, kv_head_num, head_size = key.shape
     query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
     key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
-    value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
-    key = repeat_kv(key, num_key_value_groups)
-    value = repeat_kv(value, num_key_value_groups)
     attn_output = fsdpa_op(
         query,
         key,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
index 1b4af58a..66a17877 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
@@ -86,7 +86,6 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )
 
-        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -161,7 +160,6 @@ class Qwen3Attention(nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
-                num_key_value_groups=self.num_key_value_groups,
             )
         # Decode
         else:
@@ -277,7 +275,6 @@ class Qwen3Model(nn.Module):
         )
 
         residual = None
-        lazy_mode = htorch.utils.internal.is_lazy()
         if lazy_mode:
             htorch.core.mark_step()
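
Side note on the two cli.py hunks: un-commenting `logger.remove()` restores the usual loguru setup, where the default stderr handler is dropped before a custom stdout sink is added; otherwise each record would be emitted by both sinks. A minimal standalone sketch of that pattern follows (the log level and the test message are illustrative, not taken from the diff):

```python
import sys
from loguru import logger

# Drop the default stderr handler first; without this, every record is
# written twice (once by the default sink, once by the sink added below).
logger.remove()
logger.add(
    sys.stdout,
    format="{message}",
    level="INFO",      # illustrative; the real CLI derives this from its options
    backtrace=True,
    diagnose=False,
)

logger.info("server starting")  # printed once, on stdout, message only
```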
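For reference on the helper removed from hpu.py: its own docstring describes the standard grouped-query-attention head expansion, equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`. A standalone sketch (tensor sizes made up) that reproduces the deleted function and checks that equivalence:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_key_value_heads, seqlen, head_dim) to
    # (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 4, 16, 64)      # 4 KV heads
expanded = repeat_kv(kv, n_rep=8)   # expanded to 32 heads
assert expanded.shape == (2, 32, 16, 64)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=8, dim=1))
```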