Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00

feat: improve star coder to support multi lora layers
Commit 31778a6508, parent 5f78ec32a5
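The diff below adds per-layer LoRA support to the StarCoder2 projections: the fused q/k/v and gate/up column-parallel layers are wrapped in TensorParallelMultiAdapterLinear, and the o_proj and down_proj row-parallel layers in TensorParallelAdapterRowLinear, each keyed by the decoder layer index. As a rough illustration of the multi-adapter wrapping idea, here is a minimal, self-contained plain-PyTorch sketch (not TGI's actual classes; MultiAdapterLinearSketch and its fields are invented for this example):

# Minimal sketch of the wrapping idea used below: the base linear computes the
# fused q/k/v (or gate/up) output, and a per-slice LoRA delta is added on top,
# keyed by the decoder layer index. All names here are illustrative, not TGI's.
import torch
import torch.nn as nn


class MultiAdapterLinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, layer_id: int, sizes: list, rank: int = 8):
        super().__init__()
        self.base = base          # fused projection, e.g. q/k/v packed on dim 0
        self.layer_id = layer_id  # which decoder layer these adapters belong to
        self.sizes = sizes        # output width of each fused slice
        # One (A, B) LoRA pair per slice; in TGI these come from loaded adapters.
        self.loras = nn.ModuleList(
            nn.Sequential(
                nn.Linear(base.in_features, rank, bias=False),
                nn.Linear(rank, size, bias=False),
            )
            for size in sizes
        )

    def forward(self, x):
        # Concatenate the per-slice deltas so they line up with the fused output.
        delta = torch.cat([lora(x) for lora in self.loras], dim=-1)
        return self.base(x) + delta


if __name__ == "__main__":
    hidden, q_dim, kv_dim = 64, 64, 16
    base = nn.Linear(hidden, q_dim + 2 * kv_dim, bias=False)
    qkv = MultiAdapterLinearSketch(base, layer_id=0, sizes=[q_dim, kv_dim, kv_dim])
    print(qkv(torch.randn(2, hidden)).shape)  # torch.Size([2, 96])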
@@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
         )


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+        head_size = config.hidden_size // config.num_attention_heads
+        sizes = [
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ]
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=config.use_bias,
         )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )


 def _load_gqa(config, prefix: str, weights):
@@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
 class Starcoder2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
         self.kv_scales = get_kv_scales(weights, f"{prefix}")

-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=config.use_bias,
+            bias=getattr(config, "use_bias", False),
         )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -305,7 +330,7 @@ class Starcoder2MLP(nn.Module):


 class Starcoder2GatedMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -319,19 +344,37 @@ class Starcoder2GatedMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+        sizes = [
+            config.intermediate_size,
+            config.intermediate_size,
+        ]
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
             bias=config.use_bias,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=config.use_bias,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -358,7 +401,7 @@ class Starcoder2Layer(nn.Module):
         super().__init__()
         prefix = f"model.layers.{layer_id}"
         self.self_attn = Starcoder2Attention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
         )

         self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
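For the row-parallel side (o_proj and down_proj in the hunks above), the commit applies the analogous single-adapter wrapping through TensorParallelAdapterRowLinear, passing the layer index and a layer name. A minimal sketch of that idea, again in plain PyTorch with invented names rather than TGI's implementation:

# Sketch of an adapter-aware row-parallel projection: one named LoRA delta is
# added after the base o_proj / down_proj output, keyed by layer index and name.
import torch
import torch.nn as nn


class AdapterRowLinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, layer_id: int, layer_name: str, rank: int = 8):
        super().__init__()
        self.base = base
        self.layer_id = layer_id      # decoder layer index, as passed in the diff
        self.layer_name = layer_name  # "o_proj" or "down_proj"
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)

    def forward(self, x):
        # Base projection plus the low-rank adapter delta.
        return self.base(x) + self.lora_b(self.lora_a(x))


if __name__ == "__main__":
    o_proj = AdapterRowLinearSketch(nn.Linear(64, 64, bias=False), 0, "o_proj")
    print(o_proj(torch.randn(2, 64)).shape)  # torch.Size([2, 64])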