Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-22 15:50:17 +00:00)
hotfix: qwen2 formatting (#3093)

* hotfix: qwen2 formatting
* cargo fmt

Parent: c5ecc7a4de
Commit: 124398fa57
@@ -1215,16 +1215,16 @@ example = json ! ({"error": "Incomplete generation"})),
     )
 )]
 #[instrument(
     skip_all,
     fields(
         parameters,
         total_time,
         validation_time,
         queue_time,
         inference_time,
         time_per_token,
         seed,
     )
 )]
 pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,

@@ -51,6 +51,7 @@ def load_attention(config, prefix, weights, layer_id):
         process_group=weights.process_group,
     )
 
+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
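
The asserts visible in this hunk guard the tensor-parallel sharding that _load_gqa performs: the hidden size must split evenly into heads, and the heads must split evenly across the process group. Below is a minimal sketch of those checks, where hidden_size, num_attention_heads, num_key_value_heads, and world_size are illustrative stand-ins for the config values and weights.process_group.size(); the key/value-head check is an assumption, not shown in the hunk.

# Minimal sketch of the divisibility checks behind _load_gqa's sharding.
# hidden_size, num_attention_heads, num_key_value_heads, and world_size are
# illustrative stand-ins for config values and weights.process_group.size().
def check_gqa_sharding(
    hidden_size: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    world_size: int,
) -> int:
    assert hidden_size % num_attention_heads == 0    # integer head size
    assert num_attention_heads % world_size == 0     # query heads shard evenly
    assert num_key_value_heads % world_size == 0     # kv heads shard evenly (assumption)
    return hidden_size // num_attention_heads        # per-head size

# Illustrative Qwen2-7B-like shapes on 2 tensor-parallel ranks.
print(check_gqa_sharding(3584, 28, 4, 2))  # 128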

@@ -132,7 +133,7 @@ class Qwen2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
-        adapter_data
+        adapter_data,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
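
The hunk ends at the qkv.split call, which separates the fused projection's output into query heads and packed key/value heads. The sketch below shows that split in isolation; the projection, split sizes, and shapes are assumptions for illustration, not the exact TGI code (which also threads adapter_data through the projection).

import torch

# Sketch: splitting a fused QKV projection into query and packed key/value
# for grouped-query attention. All sizes below are illustrative assumptions.
num_heads, num_kv_heads, head_size = 8, 2, 64
qkv_proj = torch.nn.Linear(
    num_heads * head_size,
    (num_heads + 2 * num_kv_heads) * head_size,
    bias=True,
)

hidden_states = torch.randn(4, num_heads * head_size)  # 4 tokens
qkv = qkv_proj(hidden_states)
query, kv = qkv.split(
    [num_heads * head_size, 2 * num_kv_heads * head_size], dim=1
)
query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_kv_heads, head_size)            # key and value stacked
print(query.shape, kv.shape)  # (4, 8, 64) and (4, 2, 2, 64)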

@@ -244,7 +245,9 @@ class Qwen2MLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )
 
 
 class Qwen2Layer(nn.Module):
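
The reformatted return is a gated MLP: the fused gate_up_proj output is viewed as two halves along a new axis, the activation is applied to the gate half and multiplied by the up half, then the result is projected back down. Here is a self-contained sketch of that computation, using plain nn.Linear layers, illustrative sizes, and SiLU as an assumed activation, and omitting TGI's sharded projections and adapter_data plumbing.

import torch
from torch import nn
import torch.nn.functional as F

# Sketch of the gated MLP computation from the hunk above, with plain linear
# layers and illustrative sizes; SiLU stands in for self.act.
hidden_size, intermediate_size = 64, 128
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

hidden_states = torch.randn(4, hidden_size)              # 4 tokens
gate_up_states = gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
out = down_proj(F.silu(gate_up_states[:, 0]) * gate_up_states[:, 1])
print(out.shape)  # torch.Size([4, 64])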

@@ -254,7 +257,9 @@ class Qwen2Layer(nn.Module):
         self.self_attn = Qwen2Attention(
             index=layer_id, prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id)
+        self.mlp = Qwen2MLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
+        )
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
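
FastRMSNorm.load in this hunk loads the layer's normalization weights with eps=config.rms_norm_eps. As a point of reference for what that normalization computes, here is a plain PyTorch RMSNorm, a sketch of the math only, not TGI's fused FastRMSNorm kernel.

import torch

# Reference RMSNorm: x * rsqrt(mean(x**2) + eps), scaled by a learned weight.
# This mirrors the math behind FastRMSNorm; the fused kernel itself differs.
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

hidden_states = torch.randn(4, 64)
weight = torch.ones(64)
print(rms_norm(hidden_states, weight).shape)  # torch.Size([4, 64])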

@@ -293,7 +298,7 @@ class Qwen2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
-            adapter_data
+            adapter_data,
         )
         hidden_states = attn_output + residual
 
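
The hidden_states = attn_output + residual line is the residual add of a pre-norm decoder block: normalize, attend, add the un-normalized input back, then do the same around the MLP. A toy version of that control flow is sketched below; the module names, nn.LayerNorm stand-ins, and simplified signatures are assumptions, not the Qwen2Layer API, which also passes rotary embeddings, KV-cache state, seqlen, and adapter data.

import torch
from torch import nn

class ToyDecoderLayer(nn.Module):
    # Simplified pre-norm residual block mirroring the pattern in the hunk above.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)       # stand-in for FastRMSNorm
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)   # stand-in for Qwen2Attention
        self.mlp = nn.Linear(hidden_size, hidden_size)         # stand-in for Qwen2MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        attn_output = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = attn_output + residual                 # residual add, as in the hunk
        residual = hidden_states
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return mlp_output + residual

print(ToyDecoderLayer(64)(torch.randn(4, 64)).shape)  # torch.Size([4, 64])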