Add Qwen2 multi-LoRA layers support (#3089)

Add Qwen2 multi-LoRA layers support, to solve problems like https://github.com/huggingface/text-generation-inference/issues/2881. A similar PR is https://github.com/huggingface/text-generation-inference/pull/2883.

Co-authored-by: hjs <hjs@pku.edu.cn>
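For reviewers unfamiliar with the adapter layers: the change wraps each projection in an adapter-aware layer and threads `adapter_data` through every `forward`, so a per-request LoRA delta can be applied on top of the frozen base weights. Below is a minimal conceptual sketch of what such a wrapper does. It is not the real `TensorParallelMultiAdapterLinear` (which handles batched adapters and tensor parallelism); the class and argument names are illustrative only.

```python
import torch
import torch.nn as nn


class ToyAdapterLinear(nn.Module):
    """Illustrative stand-in for the adapter-aware wrappers used in this PR.

    It only shows why forward() now takes adapter information: the frozen base
    projection runs once, then a low-rank (LoRA) correction is added on top.
    The real layers select and batch adapters per request via `adapter_data`.
    """

    def __init__(self, base_linear: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base_linear = base_linear  # e.g. the fused q/k/v or the o_proj weight
        self.lora_a = nn.Parameter(torch.zeros(base_linear.in_features, rank))
        self.lora_b = nn.Parameter(torch.zeros(rank, base_linear.out_features))
        self.scale = scale

    def forward(self, hidden_states: torch.Tensor, apply_adapter: bool) -> torch.Tensor:
        out = self.base_linear(hidden_states)
        if apply_adapter:
            # Low-rank update: (x @ A) @ B, added to the base projection output.
            out = out + (hidden_states @ self.lora_a @ self.lora_b) * self.scale
        return out
```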
Author: EachSheep
Date: 2025-03-10 19:42:59 +08:00
Commit: bbe218a4f7 (parent 58a65f7914)

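One detail in the diff below worth calling out: the new `load_attention` builds a per-projection `sizes` list so the adapter wrapper knows where each of q_proj/k_proj/v_proj lives inside the fused QKV output. As a quick sanity check, using Qwen2-7B-style shapes (hidden_size 3584, 28 attention heads, 4 key-value heads; the numbers are quoted from the published Qwen2-7B config and are only for illustration):

```python
# Qwen2-7B-style shapes, for illustration only.
hidden_size = 3584
num_attention_heads = 28
num_key_value_heads = 4

head_size = hidden_size // num_attention_heads  # 128
sizes = [
    head_size * num_attention_heads,   # 3584 -> q_proj slice
    head_size * num_key_value_heads,   # 512  -> k_proj slice
    head_size * num_key_value_heads,   # 512  -> v_proj slice
]
print(sizes)       # [3584, 512, 512]
print(sum(sizes))  # 4608, the column width of the fused QKV projection
```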

@@ -11,6 +11,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -23,18 +25,31 @@ from text_generation_server.layers.layernorm import (
 )
 
 
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
+    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = [
+        head_size * config.num_attention_heads,
+        head_size * config.num_key_value_heads,
+        head_size * config.num_key_value_heads,
+    ]
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=True,
         )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )
 
 
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
@@ -52,6 +67,7 @@ def _load_gqa(config, prefix: str, weights):
 class Qwen2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -83,16 +99,22 @@ class Qwen2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
         self.kv_scales = get_kv_scales(weights, f"{prefix}")
 
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
 
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -110,8 +132,9 @@ class Qwen2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -163,11 +186,13 @@ class Qwen2Attention(torch.nn.Module):
             kv_scales=self.kv_scales,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class Qwen2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -181,27 +206,45 @@ class Qwen2MLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+        sizes = [
+            config.intermediate_size,
+            config.intermediate_size,
+        ]
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id=index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
 
-    def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
 
 
 class Qwen2Layer(nn.Module):
@@ -209,9 +252,9 @@ class Qwen2Layer(nn.Module):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = Qwen2Attention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            index=layer_id, prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id)
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
@@ -234,6 +277,7 @@ class Qwen2Layer(nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
         normed_hidden_states, residual = self.input_layernorm(hidden_states)
@@ -249,12 +293,13 @@ class Qwen2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
+            adapter_data
         )
 
         hidden_states = attn_output + residual
 
         # faster post attention rms norm
         hidden_states, residual = self.post_attention_layernorm(hidden_states)
-        mlp_output = self.mlp(hidden_states)
+        mlp_output = self.mlp(hidden_states, adapter_data)
         hidden_states = mlp_output + residual
         return hidden_states
@@ -301,6 +346,7 @@ class Qwen2Model(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -324,6 +370,7 @@ class Qwen2Model(torch.nn.Module):
                 seqlen,
                 max_s,
                 prefill_cache_indices,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states)
@@ -396,6 +443,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
             max_s,
             true_max_s,
             prefill_cache_indices,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
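With this in place, a Qwen2 deployment can be exercised the same way as the other multi-LoRA models. A client-side sketch, assuming the server was started with the relevant adapters preloaded (e.g. via the `--lora-adapters` launcher option) and is listening on the default port; the adapter id is a placeholder and the parameter names follow TGI's multi-LoRA documentation:

```python
import requests

# Placeholder adapter id; the server is assumed to have been launched with it
# preloaded, e.g. --lora-adapters my-org/qwen2-lora-adapter
payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {
        "max_new_tokens": 64,
        # Selects which preloaded LoRA adapter to apply for this request.
        "adapter_id": "my-org/qwen2-lora-adapter",
    },
}

resp = requests.post("http://localhost:3000/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])
```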