feat: improve StarCoder2 to support multi-LoRA layers

commit 31778a6508
parent 5f78ec32a5
Author: drbh
Date:   2025-01-07 00:21:58 +00:00


@@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
         )


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
+    # Defined before the branch: the adapter wrapper below needs these
+    # regardless of whether the GQA or the fused-load path is taken.
+    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = [
+        head_size * config.num_attention_heads,
+        head_size * config.num_key_value_heads,
+        head_size * config.num_key_value_heads,
+    ]
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=config.use_bias,
         )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )


 def _load_gqa(config, prefix: str, weights):
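The `sizes` list is what lets the adapter wrapper treat the fused q/k/v output as three separate LoRA targets: the base projection runs once, and each named slice of its output can receive its own low-rank delta. A minimal, self-contained sketch of that idea follows; it is a toy stand-in, not TGI's TensorParallelMultiAdapterLinear, which additionally handles tensor-parallel sharding and batched per-request adapters.

import torch

class ToyMultiAdapterLinear(torch.nn.Module):
    """Toy stand-in for the fused-projection adapter wrapper."""

    def __init__(self, base, sizes, loras):
        super().__init__()
        self.base = base    # fused projection, e.g. q/k/v in one Linear
        self.sizes = sizes  # output width of each fused slice, in order
        self.loras = loras  # per-slice (A, B) low-rank pair, or None

    def forward(self, x):
        out = self.base(x)
        deltas = []
        for size, lora in zip(self.sizes, self.loras):
            if lora is None:
                deltas.append(torch.zeros(x.shape[0], size))
            else:
                a, b = lora  # a: (in_features, r), b: (r, size)
                deltas.append(x @ a @ b)
        return out + torch.cat(deltas, dim=1)

# Rank-2 adapter on the "q" slice only; "k" and "v" stay at base weights.
base = torch.nn.Linear(16, 4 + 2 + 2)
lora_q = (torch.randn(16, 2), torch.randn(2, 4))
layer = ToyMultiAdapterLinear(base, sizes=[4, 2, 2], loras=[lora_q, None, None])
print(layer(torch.randn(3, 16)).shape)  # torch.Size([3, 8])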
@@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
 class Starcoder2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
+
         self.kv_scales = get_kv_scales(weights, f"{prefix}")
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=config.use_bias,
+            bias=getattr(config, "use_bias", False),
         )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
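With both `query_key_value` and `o_proj` now adapter wrappers, the forward pass is expected to hand them the batch's adapter metadata as a second argument, mirroring TGI's other flash models; the `adapter_data` name below is that assumed plumbing, not something introduced by this hunk. The key invariant the sketch demonstrates: with no adapters loaded, the wrapper must behave exactly like the base layer, so non-LoRA StarCoder2 deployments are unaffected.

import torch

class NoOpAdapterLinear(torch.nn.Module):
    """Stand-in for an adapter wrapper with no adapters loaded."""

    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x, adapter_data=None):
        # No active adapters: reduce to the base projection.
        return self.base(x)

o_proj = NoOpAdapterLinear(torch.nn.Linear(8, 8))
attn_output = torch.randn(2, 8)
out = o_proj(attn_output, adapter_data=None)  # adapter_data threaded through
print(out.shape)  # torch.Size([2, 8])

The switch from `config.use_bias` to `getattr(config, "use_bias", False)` also makes the o_proj load tolerant of configs that omit the field.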
@@ -305,7 +330,7 @@ class Starcoder2MLP(nn.Module):


 class Starcoder2GatedMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -319,19 +344,37 @@ class Starcoder2GatedMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+        sizes = [
+            config.intermediate_size,
+            config.intermediate_size,
+        ]
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
             bias=config.use_bias,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=config.use_bias,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
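Here `sizes` is two equal halves because the fused gate/up output is split down the middle before the gated activation. A toy forward pass with made-up dimensions, assuming SiLU for the activation (the real layer uses whatever `config.hidden_act` names):

import torch

intermediate_size = 6
x = torch.randn(2, 4)
gate_up = torch.nn.Linear(4, 2 * intermediate_size)  # fused gate + up
down = torch.nn.Linear(intermediate_size, 4)

fused = gate_up(x)
gate, up = fused.split([intermediate_size, intermediate_size], dim=-1)
hidden = torch.nn.functional.silu(gate) * up  # gated activation
print(down(hidden).shape)  # torch.Size([2, 4])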
@@ -358,7 +401,7 @@ class Starcoder2Layer(nn.Module):
         super().__init__()
         prefix = f"model.layers.{layer_id}"
         self.self_attn = Starcoder2Attention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
         )
         self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
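Threading `index=layer_id` into every wrapped module is what makes adapter weights addressable per layer and per projection: all targets inside decoder layer i register under the same index, keyed the way `layer_id` and `layer_names` suggest. An illustrative lookup only, not TGI's actual registry:

from typing import Dict, Optional, Tuple

# Hypothetical table: one LoRA pair per (layer index, layer name).
lora_tensors: Dict[Tuple[int, str], str] = {
    (0, "q_proj"): "A0_q, B0_q",
    (0, "o_proj"): "A0_o, B0_o",
    (1, "down_proj"): "A1_down, B1_down",
}

def adapter_for(index: int, layer_name: str) -> Optional[str]:
    """Look up the adapter for one projection in one decoder layer."""
    return lora_tensors.get((index, layer_name))

assert adapter_for(0, "o_proj") == "A0_o, B0_o"
assert adapter_for(2, "o_proj") is None  # no adapter targets layer 2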