From 31778a650849d854ded580af194c98e29ace03fb Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 7 Jan 2025 00:21:58 +0000
Subject: [PATCH] feat: improve star coder to support multi lora layers

---
 .../flash_starcoder2_modeling.py | 67 +++++++++++++++----
 1 file changed, 55 insertions(+), 12 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index c793982d..571dc48e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
     )
 
 
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+        head_size = config.hidden_size // config.num_attention_heads
+        sizes = [
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ]
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=config.use_bias,
         )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )
 
 
 def _load_gqa(config, prefix: str, weights):
@@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
 class Starcoder2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
         self.kv_scales = get_kv_scales(weights, f"{prefix}")
 
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=config.use_bias,
+            bias=getattr(config, "use_bias", False),
         )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -305,7 +330,7 @@ class Starcoder2MLP(nn.Module):
 
 
 class Starcoder2GatedMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -319,19 +344,37 @@ class Starcoder2GatedMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+        sizes = [
+            config.intermediate_size,
+            config.intermediate_size,
+        ]
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
             bias=config.use_bias,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=config.use_bias,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -358,7 +401,7 @@ class Starcoder2Layer(nn.Module):
         super().__init__()
         prefix = f"model.layers.{layer_id}"
         self.self_attn = Starcoder2Attention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
         )
         self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](