diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index d6569a1d..fc708e58 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -11,6 +11,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -23,18 +25,31 @@ from text_generation_server.layers.layernorm import (
 )
 
 
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
+    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = [
+        head_size * config.num_attention_heads,
+        head_size * config.num_key_value_heads,
+        head_size * config.num_key_value_heads,
+    ]
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=True,
         )
-
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )
 
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
@@ -52,6 +67,7 @@ def _load_gqa(config, prefix: str, weights):
 class Qwen2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -83,16 +99,22 @@ class Qwen2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
         self.kv_scales = get_kv_scales(weights, f"{prefix}")
 
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
 
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -110,8 +132,9 @@ class Qwen2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -163,11 +186,13 @@ class Qwen2Attention(torch.nn.Module):
             kv_scales=self.kv_scales,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class Qwen2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -181,27 +206,45 @@ class Qwen2MLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+        sizes = [
+            config.intermediate_size,
+            config.intermediate_size,
+        ]
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id=index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
 
-    def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
 
 
 class Qwen2Layer(nn.Module):
@@ -209,9 +252,9 @@ class Qwen2Layer(nn.Module):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = Qwen2Attention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            index=layer_id, prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id)
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
@@ -234,6 +277,7 @@ class Qwen2Layer(nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
         normed_hidden_states, residual = self.input_layernorm(hidden_states)
 
@@ -249,12 +293,13 @@ class Qwen2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
+            adapter_data,
         )
         hidden_states = attn_output + residual
 
         # faster post attention rms norm
         hidden_states, residual = self.post_attention_layernorm(hidden_states)
-        mlp_output = self.mlp(hidden_states)
+        mlp_output = self.mlp(hidden_states, adapter_data)
         hidden_states = mlp_output + residual
         return hidden_states
 
@@ -301,6 +346,7 @@ class Qwen2Model(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
 
@@ -324,6 +370,7 @@ class Qwen2Model(torch.nn.Module):
                 seqlen,
                 max_s,
                 prefill_cache_indices,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states)
@@ -396,6 +443,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
             max_s,
             true_max_s,
             prefill_cache_indices,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]