mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Add qwen2 multi lora layers support (#3089)
add qwen2 multi lora layers support to solve problem like https://github.com/huggingface/text-generation-inference/issues/2881, the similar PR are at https://github.com/huggingface/text-generation-inference/pull/2883 Co-authored-by: hjs <hjs@pku.edu.cn>
This commit is contained in:
parent
58a65f7914
commit
bbe218a4f7
@ -11,6 +11,8 @@ from text_generation_server.layers.attention import (
|
|||||||
Seqlen,
|
Seqlen,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
@ -23,18 +25,31 @@ from text_generation_server.layers.layernorm import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights, layer_id):
|
||||||
|
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
sizes = [
|
||||||
|
head_size * config.num_attention_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
]
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
base_layer = _load_gqa(config, prefix, weights)
|
||||||
else:
|
else:
|
||||||
return TensorParallelColumnLinear.load_multi(
|
base_layer = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=prefixes,
|
||||||
dim=0,
|
dim=0,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
return TensorParallelMultiAdapterLinear.load(
|
||||||
|
base_layer=base_layer,
|
||||||
|
layer_id=layer_id,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_gqa(config, prefix: str, weights):
|
def _load_gqa(config, prefix: str, weights):
|
||||||
assert config.hidden_size % config.num_attention_heads == 0
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
@ -52,6 +67,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
class Qwen2Attention(torch.nn.Module):
|
class Qwen2Attention(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
index: int,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -83,16 +99,22 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
config.num_key_value_heads // weights.process_group.size()
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||||
|
|
||||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
index,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.kv_head_mapping = torch.arange(
|
self.kv_head_mapping = torch.arange(
|
||||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
@ -110,8 +132,9 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
[
|
[
|
||||||
self.head_size * self.num_heads,
|
self.head_size * self.num_heads,
|
||||||
@ -163,11 +186,13 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
kv_scales=self.kv_scales,
|
kv_scales=self.kv_scales,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MLP(nn.Module):
|
class Qwen2MLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights, index):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -181,27 +206,45 @@ class Qwen2MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||||
|
sizes = [
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
]
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
prefixes=prefixes,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
dim=0,
|
dim=0,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
self.down_proj = TensorParallelRowLinear.load(
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
layer_id=index,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
index,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
self.intermediate_size = (
|
self.intermediate_size = (
|
||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
|
||||||
|
|
||||||
|
|
||||||
class Qwen2Layer(nn.Module):
|
class Qwen2Layer(nn.Module):
|
||||||
@ -209,9 +252,9 @@ class Qwen2Layer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"{prefix}.layers.{layer_id}"
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
self.self_attn = Qwen2Attention(
|
self.self_attn = Qwen2Attention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
index=layer_id, prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
)
|
)
|
||||||
self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id)
|
||||||
self.input_layernorm = FastRMSNorm.load(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
@ -234,6 +277,7 @@ class Qwen2Layer(nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
normed_hidden_states, residual = self.input_layernorm(hidden_states)
|
normed_hidden_states, residual = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
@ -249,12 +293,13 @@ class Qwen2Layer(nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data
|
||||||
)
|
)
|
||||||
hidden_states = attn_output + residual
|
hidden_states = attn_output + residual
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states)
|
||||||
mlp_output = self.mlp(hidden_states)
|
mlp_output = self.mlp(hidden_states, adapter_data)
|
||||||
hidden_states = mlp_output + residual
|
hidden_states = mlp_output + residual
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@ -301,6 +346,7 @@ class Qwen2Model(torch.nn.Module):
|
|||||||
max_s: int,
|
max_s: int,
|
||||||
true_max_s: int,
|
true_max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
adapter_data,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
@ -324,6 +370,7 @@ class Qwen2Model(torch.nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states)
|
hidden_states, _ = self.norm(hidden_states)
|
||||||
@ -396,6 +443,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
true_max_s,
|
true_max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
Loading…
Reference in New Issue
Block a user