diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py
index d6f15465..ac143bb7 100644
--- a/server/text_generation_server/adapters/lora.py
+++ b/server/text_generation_server/adapters/lora.py
@@ -289,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
             for rank_data in self.rank_data.values()
         )
 
-    @classmethod
-    def key(cls) -> str:
-        return "lora"
-
     @classmethod
     def load(
         self,
diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py
index 2cc4ba6d..da75dbcd 100644
--- a/server/text_generation_server/adapters/weights.py
+++ b/server/text_generation_server/adapters/weights.py
@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
     def has_adapter(self, adapter_index: int) -> bool:
         pass
 
-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
-
     @abstractclassmethod
     def load(
         cls,
@@ -94,7 +90,7 @@ class LayerAdapterWeights:
                 adapter_weights, meta, prefill, prefill_head_indices
             )
             if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
         return batch_data
 
 
@@ -126,8 +122,7 @@ class AdapterBatchData:
     def ranks(self) -> Set[int]:
         # TODO(travis): refactor to be less coupled to lora implementation
         ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
             if lora_data is None:
                 continue
 
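Note (not part of the diff): the hunks above remove the key() indirection, so LayerAdapterWeights.get_data() returns the batched weights object itself instead of nesting it under a "lora" key, and AdapterBatchData.data maps a layer name straight to its BatchLoraWeights. A toy Python sketch of the layout change; the layer name and stand-in object are invented for illustration:

    # "q_proj" and the stand-in object are made up; in the server the
    # value would be a BatchLoraWeights instance.
    weights = object()

    before = {"q_proj": {"lora": weights}}  # old: extra dict per layer
    after = {"q_proj": weights}             # new: layer name -> weights

    assert before["q_proj"]["lora"] is after["q_proj"]
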
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 8ec2a5ae..0d679fc8 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -4,9 +4,10 @@ import typer
 
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
 from enum import Enum
 from huggingface_hub import hf_hub_download
 
+from text_generation_server.utils.adapter import parse_lora_adapters
 from text_generation_server.utils.log import log_master
 
@@ -80,22 +81,16 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
 
-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+    lora_adapters = parse_lora_adapters(os.environ.get("LORA_ADAPTERS", None))
 
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
+    if len(lora_adapters) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
         )
 
     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
+    if len(lora_adapters) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
         log_master(
             logger.warning,
             f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
@@ -117,7 +112,7 @@ def serve(
     )
     server.serve(
         model_id,
-        lora_adapter_ids,
+        lora_adapters,
         revision,
         sharded,
         quantize,
diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py
index 0bb6db41..df5e92da 100644
--- a/server/text_generation_server/layers/lora.py
+++ b/server/text_generation_server/layers/lora.py
@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
     ) -> torch.Tensor:
         if adapter_data is None:
             return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.
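Note (not part of the diff): with parse_lora_adapters in place, LORA_ADAPTERS accepts comma-separated entries of the form "id" (resolved from the hub) or "id=path" (loaded from a local directory). A sketch of what serve() now receives; the adapter ids and paths below are invented:

    import os
    from text_generation_server.utils.adapter import parse_lora_adapters

    os.environ["LORA_ADAPTERS"] = "org/hub-adapter,local-adapter=/data/adapters/local"
    adapters = parse_lora_adapters(os.environ.get("LORA_ADAPTERS", None))
    # -> [AdapterInfo(id="org/hub-adapter", path=None),
    #     AdapterInfo(id="local-adapter", path="/data/adapters/local")]
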
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index c0366623..d674c20e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
 from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -38,6 +38,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
     load_and_merge_adapters,
+    AdapterInfo,
 )
 from text_generation_server.adapters.lora import LoraWeights
 
@@ -1125,7 +1126,7 @@ def _get_model(
 # this provides a post model loading hook to load adapters into the model after the model has been loaded
 def get_model(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -1133,8 +1134,9 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
-    adapter_to_index: dict[str, int],
+    adapter_to_index: Dict[str, int],
 ):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
     model = _get_model(
         model_id,
         lora_adapter_ids,
@@ -1147,14 +1149,14 @@ def get_model(
         max_input_tokens,
     )
 
-    if len(lora_adapter_ids) > 0:
+    if len(lora_adapters) > 0:
         target_to_layer = build_layer_weight_lookup(model.model)
 
-        for index, adapter_id in enumerate(lora_adapter_ids):
+        for index, adapter in enumerate(lora_adapters):
             # currenly we only load one adapter at a time but
             # this can be extended to merge multiple adapters
             adapter_parameters = AdapterParameters(
-                adapter_ids=[adapter_id],
+                adapter_info=[adapter],
                 weights=None,  # will be set to 1
                 merge_strategy=0,
                 density=1.0,
@@ -1162,13 +1164,13 @@ def get_model(
             )
 
             adapter_index = index + 1
-            adapter_to_index[adapter_id] = adapter_index
+            adapter_to_index[adapter.id] = adapter_index
 
             if adapter_index in model.loaded_adapters:
                 continue
 
             logger.info(
-                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
             )
             weight_names = tuple([v[0] for v in target_to_layer.values()])
             (
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 7455740a..fc199e34 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -9,11 +9,12 @@ from loguru import logger
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict
 
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.utils.adapter import AdapterInfo
 
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -192,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
 def serve(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -204,7 +205,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -227,7 +228,7 @@ def serve(
         try:
             model = get_model(
                 model_id,
-                lora_adapter_ids,
+                lora_adapters,
                 revision,
                 sharded,
                 quantize,
@@ -274,7 +275,7 @@ def serve(
     asyncio.run(
         serve_inner(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
             revision,
             sharded,
             quantize,
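Note (not part of the diff): in get_model() above, adapter_index = index + 1 leaves index 0 for the base model, and adapter_to_index maps each adapter id to its slot. A minimal sketch of that bookkeeping, with invented adapter ids and paths:

    from text_generation_server.utils.adapter import AdapterInfo

    lora_adapters = [
        AdapterInfo(id="adapter-a", path=None),
        AdapterInfo(id="adapter-b", path="/data/adapters/b"),
    ]

    adapter_to_index = {}
    for index, adapter in enumerate(lora_adapters):
        adapter_to_index[adapter.id] = index + 1  # index 0: base model

    assert adapter_to_index == {"adapter-a": 1, "adapter-b": 2}
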
diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py
index 21f7bbbc..1009fc70 100644
--- a/server/text_generation_server/utils/adapter.py
+++ b/server/text_generation_server/utils/adapter.py
@@ -5,7 +5,7 @@ import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List
 
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
@@ -23,9 +23,15 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
+
+
 @dataclass
 class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
     weights: Tuple[float]
     merge_strategy: NotImplemented
     density: float
@@ -39,6 +45,22 @@ class AdapterSource:
     revision: str
 
 
+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
+
+
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
@@ -46,10 +68,13 @@ def load_and_merge_adapters(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_parameters.adapter_ids[0],
+            adapter_info.id,
+            adapter_info.path,
             weight_names,
             trust_remote_code,
         )
@@ -79,14 +104,15 @@ def _load_and_merge(
     adapters_to_merge = []
     merged_weight_names = set()
     tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")
 
         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
-                adapter_id,
+                adapter.id,
+                adapter.path,
                 weight_names,
                 trust_remote_code,
             )
@@ -146,17 +172,23 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
+    adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
 
-    adapter_config = LoraConfig.load(adapter_id, None)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
 
-    adapter_filenames = hub._cached_adapter_weight_files(
-        adapter_id, revision=revision, extension=".safetensors"
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
+            adapter_id, revision=revision, extension=".safetensors"
+        )
     )
 
     try:
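Note (not part of the diff): behavior checks for parse_lora_adapters as defined above; adapter names and paths are placeholders. Entries are comma-separated, whitespace around each entry is stripped, and an entry with more than one "=" is rejected:

    from text_generation_server.utils.adapter import AdapterInfo, parse_lora_adapters

    assert parse_lora_adapters(None) == []
    assert parse_lora_adapters("a, b=/data/b") == [
        AdapterInfo(id="a", path=None),
        AdapterInfo(id="b", path="/data/b"),
    ]

    try:
        parse_lora_adapters("bad=path=extra")
    except ValueError as err:
        assert "Invalid LoRA adapter format" in str(err)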