diff --git a/server/Makefile b/server/Makefile
index 5257b876..0099c56a 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -4,6 +4,7 @@ include Makefile-vllm
 include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
+include Makefile-lorax-punica
 
 unit-tests:
 	pytest -s -vv -m "not private" tests
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 2c066c5c..892943ba 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -85,6 +85,11 @@ def serve(
         [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
     )
 
+    if len(lora_adapter_ids) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        )
+
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
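The cli.py change above splits the comma-separated --lora-adapter-ids value and warns when any adapters are requested. Below is a minimal standalone sketch of that parsing behavior; the helper name and the example adapter ids are illustrative and not part of the patch.

def parse_lora_adapter_ids(lora_adapter_ids: str) -> list:
    # Mirrors the list comprehension in serve(): empty input yields no adapters.
    return (
        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
    )

assert parse_lora_adapter_ids("") == []
assert parse_lora_adapter_ids("org/adapter-a, org/adapter-b") == [
    "org/adapter-a",
    "org/adapter-b",
]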
diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py
index 7adfbb29..b6f005ab 100644
--- a/server/text_generation_server/layers/lora.py
+++ b/server/text_generation_server/layers/lora.py
@@ -1,12 +1,13 @@
 import math
 import os
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, List
 
 import torch
 import torch.distributed
 from accelerate import init_empty_weights
 from torch import nn
 from torch.nn import functional as F
+from torch.distributed import ProcessGroup
 
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
@@ -17,16 +18,15 @@ from text_generation_server.utils.sgmv import (
     orient_for_rank,
 )
 
-LORA = "lora"
-MEDUSA = "medusa"
-
 if TYPE_CHECKING:
     from text_generation_server.adapters import AdapterBatchData
     from text_generation_server.adapters.lora import BatchLoraWeights
 
 
 class LoraLinear(nn.Module):
-    def __init__(self, base_layer, layer_id, process_group):
+    def __init__(
+        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
+    ):
         super().__init__()
         self.base_layer = base_layer
         self.layer_id = layer_id
@@ -41,14 +41,24 @@ class LoraLinear(nn.Module):
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
-        if adapter_data is None:
-            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
-            data.get(LORA) if data is not None else None
+            data.get("lora") if data is not None else None
         )
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
+            # The 'result' tensor represents the full output, which can vary in size based on
+            # the layer type (e.g., attention vs. feed-forward layers). We define the current
+            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
+            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
+            # This approach ensures accurate LoRA application across various layer sizes and
+            # configurations, adapting to different model architectures and parallelization strategies.
+            #
+            # Example scenarios where this is necessary:
+            # 1. The adapter's size doesn't evenly divide across GPUs.
+            # 2. We're processing the last segment which might be smaller.
+            # 3. Different projection layers (q, k, v) have different sizes.
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
             else:
@@ -58,55 +68,57 @@ class LoraLinear(nn.Module):
                 lora_a_ptr = rank_segments.lora_a_ptr
                 lora_b_ptr = rank_segments.lora_b_ptr
 
+                if lora_a_ptr is None or lora_b_ptr is None:
+                    raise ValueError("LoRA data is missing")
+
                 if data.use_sgmv:
                     # Use SGMV for prefill
-                    if lora_a_ptr is not None and lora_b_ptr is not None:
-                        v = lora_a_sgmv_cutlass(
-                            input,
-                            rank_segments.tmp_shrink,
-                            lora_a_ptr,
-                            rank_segments.segment_starts,
-                            rank_segments.segment_ends,
-                            self.layer_id,
-                            r,
-                        )
+                    v = lora_a_sgmv_cutlass(
+                        input,
+                        rank_segments.tmp_shrink,
+                        lora_a_ptr,
+                        rank_segments.segment_starts,
+                        rank_segments.segment_ends,
+                        self.layer_id,
+                        r,
+                    )
 
-                        if self.process_group.size() > 1:
-                            v = self.collect_lora_a(v)
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
 
-                        lora_b_sgmv_cutlass(
-                            proj,
-                            v,
-                            rank_segments.tmp_expand,
-                            lora_b_ptr,
-                            rank_segments.segment_starts,
-                            rank_segments.segment_ends,
-                            self.layer_id,
-                        )
+                    lora_b_sgmv_cutlass(
+                        proj,
+                        v,
+                        rank_segments.tmp_expand,
+                        lora_b_ptr,
+                        rank_segments.segment_starts,
+                        rank_segments.segment_ends,
+                        self.layer_id,
+                    )
                 else:
                     # Use BGMV for decode
-                    if lora_a_ptr is not None and lora_b_ptr is not None:
-                        v = torch.zeros(
-                            (input.size(0), r), dtype=input.dtype, device=input.device
-                        )
-                        add_lora_a_bgmv(
-                            v,
-                            input,
-                            lora_a_ptr,
-                            rank_segments.indices,
-                            self.layer_id,
-                        )
+                    v = torch.zeros(
+                        (input.size(0), r), dtype=input.dtype, device=input.device
+                    )
+                    # TODO: error with [-1, 0], but not [0, -1]
+                    add_lora_a_bgmv(
+                        v,
+                        input,
+                        lora_a_ptr,
+                        rank_segments.indices,
+                        self.layer_id,
+                    )
 
-                        if self.process_group.size() > 1:
-                            v = self.collect_lora_a(v)
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
 
-                        add_lora_b_bgmv(
-                            proj,
-                            v,
-                            lora_b_ptr,
-                            rank_segments.indices,
-                            self.layer_id,
-                        )
+                    add_lora_b_bgmv(
+                        proj,
+                        v,
+                        lora_b_ptr,
+                        rank_segments.indices,
+                        self.layer_id,
+                    )
 
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
@@ -149,13 +161,27 @@ class LoraLinear(nn.Module):
 
 
 class TensorParallelMultiAdapterLinear(LoraLinear):
-    def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         super().__init__(base_layer, layer_id, process_group)
         self.layer_names = layer_names
         self.sizes = sizes
 
     @classmethod
-    def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
+    def load(
+        cls,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         return TensorParallelMultiAdapterLinear(
             base_layer, layer_id, layer_names, sizes, process_group
         )
@@ -165,6 +191,10 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
     ) -> torch.Tensor:
         result = self.base_layer(input)
 
+        # noop if no layer names are provided (e.g. for models without adapters)
+        if self.layer_names is None:
+            return result
+
         # handle models like Bloom that have inputs of shape
         # (batch_size, sequence_length, hidden_size)
         # we need to reshape them to (batch_size * sequence_length, hidden_size)
@@ -178,7 +208,12 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
         offset = 0
         for i, layer_name in enumerate(self.layer_names):
             start_idx = offset // self.process_group.size()
-
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
             if self.sizes is not None:
                 offset += self.sizes[i]
                 end_idx = offset // self.process_group.size()
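For reference, the SGMV (prefill) and BGMV (decode) kernels dispatched above compute the standard LoRA update, shrinking the input through lora_a and expanding through lora_b, and the result is accumulated into the tensor-parallel slice selected by start_idx/end_idx. Below is a plain-PyTorch sketch of that computation under assumed shapes; the names lora_delta, lora_a, lora_b, and scale are illustrative, and the real kernels operate on batched pointer/segment data rather than single matrices.

import torch

def lora_delta(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor, scale: float) -> torch.Tensor:
    # shrink: (tokens, hidden) @ (hidden, r) -> (tokens, r)
    v = x @ lora_a
    # expand: (tokens, r) @ (r, slice_width) -> (tokens, slice_width)
    return (v @ lora_b) * scale

tokens, hidden, out_width, r = 4, 64, 48, 8
result = torch.zeros(tokens, out_width)  # base layer output for this rank
start_idx, end_idx = 16, 32              # segment handled for one projection
delta = lora_delta(
    torch.randn(tokens, hidden),
    torch.randn(hidden, r),
    torch.randn(r, end_idx - start_idx),
    scale=1.0,
)
# Mirrors forward_layer_type above: accumulate the adapter delta into the segment only.
result[:, start_idx:end_idx] += delta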
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 06558379..f787b846 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -56,28 +56,37 @@ if SYSTEM == "rocm":
 def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = None
+    prefixes = None
 
-    # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.qkv_proj"
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=prefix,
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.W_pack"
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=f"{prefix}.W_pack",
+            prefix=prefix,
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
     else:
-        # otherwise, load the default attention based on the number of heads
+        prefixes = ["q_proj", "k_proj", "v_proj"]
+        sizes = [
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ]
         base_layer = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -86,18 +95,13 @@ def load_attention(config, prefix, weights, layer_id):
             bias=bias,
         )
 
-    head_size = config.hidden_size // config.num_attention_heads
-    return TensorParallelMultiAdapterLinear.load(
-        base_layer,
-        layer_id,
-        ["q_proj", "k_proj", "v_proj"],
-        sizes=[
-            head_size * config.num_attention_heads,
-            head_size * config.num_key_value_heads,
-            head_size * config.num_key_value_heads,
-        ],
-        process_group=weights.process_group,
-    )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
@@ -237,34 +241,39 @@ class LlamaMLP(nn.Module):
                 ),
             )
         )
+        prefixes = None
+        sizes = None
+
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
                 bias=bias,
             )
         else:
+            prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+            sizes = [
+                config.intermediate_size,
+                config.intermediate_size,
+            ]
             gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
-                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+                prefixes=prefixes,
                 weights=weights,
                 dim=0,
                 bias=bias,
             )
 
-            self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-                gate_up_proj,
-                index,
-                ["gate_proj", "up_proj"],
-                sizes=[
-                    config.intermediate_size,
-                    config.intermediate_size,
-                ],
-                process_group=weights.process_group,
-            )
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
 
         down_proj = TensorParallelRowLinear.load(
             config,
@@ -273,15 +282,12 @@ class LlamaMLP(nn.Module):
             bias=bias,
         )
 
-        if config.model_type == "phi3":
-            self.down_proj = down_proj
-        else:
-            self.down_proj = TensorParallelAdapterRowLinear.load(
-                down_proj,
-                index,
-                "down_proj",
-                process_group=weights.process_group,
-            )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
 
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
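The load_attention and LlamaMLP changes above pass explicit layer_names/sizes so the multi-adapter wrapper can slice each projection (q/k/v, or gate/up) out of the fused output. Below is a worked example of the q/k/v sizes under assumed grouped-query-attention config values (4096 hidden size, 32 query heads, 8 KV heads); these numbers are illustrative and not taken from the patch.

hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_size = hidden_size // num_attention_heads   # 128
sizes = [
    head_size * num_attention_heads,             # q_proj -> 4096
    head_size * num_key_value_heads,             # k_proj -> 1024
    head_size * num_key_value_heads,             # v_proj -> 1024
]

# With 2-way tensor parallelism, TensorParallelMultiAdapterLinear derives each
# projection's slice of the fused output from these sizes, e.g.
#   q_proj -> result[:, 0:2048], k_proj -> result[:, 2048:2560], v_proj -> result[:, 2560:3072]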
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index ce57fd5c..11eece48 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -1,6 +1,7 @@
 import torch
 import os
 from loguru import logger
+from typing import Dict
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
@@ -30,9 +31,9 @@ def set_model_id(model_id: str):
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
 global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX = None
+ADAPTER_TO_INDEX: Dict[str, int] = None
 
 
-def set_adapter_to_index(adapter_to_index: dict):
+def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
     ADAPTER_TO_INDEX = adapter_to_index
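The globals.py change only tightens the typing of the ADAPTER_TO_INDEX mapping from adapter id to adapter index. A small usage sketch of that mapping follows; the adapter names and indices are made up for illustration.

from typing import Dict

ADAPTER_TO_INDEX: Dict[str, int] = {}

def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index

set_adapter_to_index({"org/adapter-a": 1, "org/adapter-b": 2})
assert ADAPTER_TO_INDEX.get("org/adapter-a") == 1
assert ADAPTER_TO_INDEX.get("not-registered") is None  # unregistered ids are simply absent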