From 124398fa57edfc62ba9a3384fe711d78ce1bb3f4 Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Mon, 10 Mar 2025 16:19:50 +0100
Subject: [PATCH] hotfix: qwen2 formatting (#3093)

* hotfix: qwen2 formatting

* cargo fmt
---
 router/src/server.rs                        | 20 ++++++++++----------
 .../custom_modeling/flash_qwen2_modeling.py | 13 ++++++++-----
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 0b4da9ad..d201f51a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1215,16 +1215,16 @@ example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(
-skip_all,
-fields(
-parameters,
-total_time,
-validation_time,
-queue_time,
-inference_time,
-time_per_token,
-seed,
-)
+    skip_all,
+    fields(
+        parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
 )]
 pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index fc708e58..9d956222 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -51,6 +51,7 @@ def load_attention(config, prefix, weights, layer_id):
         process_group=weights.process_group,
     )

+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -132,7 +133,7 @@ class Qwen2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
-        adapter_data
+        adapter_data,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
@@ -244,7 +245,9 @@ class Qwen2MLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )


 class Qwen2Layer(nn.Module):
@@ -254,7 +257,9 @@ class Qwen2Layer(nn.Module):
         self.self_attn = Qwen2Attention(
             index=layer_id, prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id)
+        self.mlp = Qwen2MLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
+        )
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
@@ -293,7 +298,7 @@ class Qwen2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
-            adapter_data
+            adapter_data,
         )

         hidden_states = attn_output + residual