feat: enable lora load from directory

drbh 2024-07-09 02:38:50 +00:00
parent 70dc958fb8
commit 4b569341e6
7 changed files with 70 additions and 52 deletions

View File

@@ -289,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
             for rank_data in self.rank_data.values()
         )
-    @classmethod
-    def key(cls) -> str:
-        return "lora"
     @classmethod
     def load(
         self,

View File

@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
     def has_adapter(self, adapter_index: int) -> bool:
         pass
-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
     @abstractclassmethod
     def load(
         cls,
@@ -94,7 +90,7 @@ class LayerAdapterWeights:
                 adapter_weights, meta, prefill, prefill_head_indices
             )
             if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
         return batch_data
@@ -126,8 +122,7 @@ class AdapterBatchData:
     def ranks(self) -> Set[int]:
         # TODO(travis): refactor to be less coupled to lora implementation
         ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
             if lora_data is None:
                 continue
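Since LoRA is the only adapter type in play, the extra keying layer provided by BatchAdapterWeights.key() is dropped: AdapterBatchData.data (and the per-layer batch data built in LayerAdapterWeights) now maps a layer name directly to its BatchLoraWeights. A standalone sketch of the shape change, with plain dicts standing in for the real dataclasses and "q_proj" as a made-up layer name:

layer_weights = object()  # stands in for a BatchLoraWeights instance

old_data = {"q_proj": {"lora": layer_weights}}  # before: nested under batch_type.key()
new_data = {"q_proj": layer_weights}            # after: one adapter type, no extra keying

assert old_data["q_proj"]["lora"] is new_data["q_proj"]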

View File

@@ -4,9 +4,10 @@ import typer
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters
 from text_generation_server.utils.log import log_master
@@ -80,22 +81,16 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+    lora_adapters = parse_lora_adapters(os.environ.get("LORA_ADAPTERS", None))
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
+    if len(lora_adapters) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
         )
     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
+    if len(lora_adapters) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
         log_master(
             logger.warning,
             f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
@@ -117,7 +112,7 @@ def serve(
         )
     server.serve(
         model_id,
-        lora_adapter_ids,
+        lora_adapters,
         revision,
         sharded,
         quantize,
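The raw comma-splitting of LORA_ADAPTERS is replaced by parse_lora_adapters, which returns structured AdapterInfo records and is what makes the new id=path form possible. A usage sketch, assuming the text_generation_server package is importable; the adapter names and paths below are placeholders:

from text_generation_server.utils.adapter import parse_lora_adapters

# Hub IDs and local "id=/path/to/dir" entries can be mixed in the same value.
adapters = parse_lora_adapters("some-org/some-adapter,my-adapter=/data/adapters/my-adapter")
for adapter in adapters:
    # path is None for hub-hosted adapters and a directory for local ones
    print(adapter.id, adapter.path)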

View File

@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
     ) -> torch.Tensor:
         if adapter_data is None:
             return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.

View File

@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
 from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -38,6 +38,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
     load_and_merge_adapters,
+    AdapterInfo,
 )
 from text_generation_server.adapters.lora import LoraWeights
@@ -1125,7 +1126,7 @@ def _get_model(
 # this provides a post model loading hook to load adapters into the model after the model has been loaded
 def get_model(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -1133,8 +1134,9 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
-    adapter_to_index: dict[str, int],
+    adapter_to_index: Dict[str, int],
 ):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
     model = _get_model(
         model_id,
         lora_adapter_ids,
@@ -1147,14 +1149,14 @@ def get_model(
         max_input_tokens,
     )
-    if len(lora_adapter_ids) > 0:
+    if len(lora_adapters) > 0:
         target_to_layer = build_layer_weight_lookup(model.model)
-        for index, adapter_id in enumerate(lora_adapter_ids):
+        for index, adapter in enumerate(lora_adapters):
             # currenly we only load one adapter at a time but
             # this can be extended to merge multiple adapters
             adapter_parameters = AdapterParameters(
-                adapter_ids=[adapter_id],
+                adapter_info=[adapter],
                 weights=None, # will be set to 1
                 merge_strategy=0,
                 density=1.0,
@@ -1162,13 +1164,13 @@ def get_model(
             )
             adapter_index = index + 1
-            adapter_to_index[adapter_id] = adapter_index
+            adapter_to_index[adapter.id] = adapter_index
             if adapter_index in model.loaded_adapters:
                 continue
             logger.info(
-                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
             )
             weight_names = tuple([v[0] for v in target_to_layer.values()])
             (
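get_model now takes the full AdapterInfo list, derives the plain ID list for _get_model, and keeps the adapter-to-index bookkeeping keyed by adapter.id. A small self-contained sketch of that bookkeeping (the AdapterInfo stand-in below mirrors the dataclass added in this commit; indices start at 1, matching adapter_index = index + 1 above):

from dataclasses import dataclass
from typing import Optional

@dataclass
class AdapterInfo:  # stand-in with the same fields as the real dataclass
    id: str
    path: Optional[str]

lora_adapters = [AdapterInfo("adapter-a", None), AdapterInfo("adapter-b", "/data/adapter-b")]
lora_adapter_ids = [adapter.id for adapter in lora_adapters]
adapter_to_index = {adapter.id: i + 1 for i, adapter in enumerate(lora_adapters)}

assert lora_adapter_ids == ["adapter-a", "adapter-b"]
assert adapter_to_index == {"adapter-a": 1, "adapter-b": 2}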

View File

@@ -9,11 +9,12 @@ from loguru import logger
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.utils.adapter import AdapterInfo
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -192,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -204,7 +205,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -227,7 +228,7 @@ def serve(
         try:
             model = get_model(
                 model_id,
-                lora_adapter_ids,
+                lora_adapters,
                 revision,
                 sharded,
                 quantize,
@@ -274,7 +275,7 @@ def serve(
     asyncio.run(
         serve_inner(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
             revision,
             sharded,
             quantize,

View File

@@ -5,7 +5,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
@@ -23,9 +23,15 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"
+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
 @dataclass
 class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
     weights: Tuple[float]
     merge_strategy: NotImplemented
     density: float
@@ -39,6 +45,22 @@ class AdapterSource:
     revision: str
+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
@@ -46,10 +68,13 @@ def load_and_merge_adapters(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_parameters.adapter_ids[0],
+            adapter_info.id,
+            adapter_info.path,
             weight_names,
             trust_remote_code,
         )
@@ -79,14 +104,15 @@ def _load_and_merge(
     adapters_to_merge = []
     merged_weight_names = set()
     tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")
         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
-                adapter_id,
+                adapter.id,
+                adapter.path,
                 weight_names,
                 trust_remote_code,
             )
@@ -146,18 +172,24 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
+    adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
-    adapter_config = LoraConfig.load(adapter_id, None)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
-    adapter_filenames = hub._cached_adapter_weight_files(
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
         adapter_id, revision=revision, extension=".safetensors"
     )
+    )
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
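When a local path is given, load_module_map reads the LoRA config from that directory and picks up its .safetensors weight files directly instead of going through the hub cache. A minimal sketch of what the directory-side lookup is assumed to do; it mirrors, but is not, the private hub._adapter_weight_files_from_dir helper, and the expected directory layout (adapter_config.json plus *.safetensors shards) is an assumption:

from pathlib import Path
from typing import List

def adapter_weight_files_from_dir(adapter_path: str, extension: str = ".safetensors") -> List[Path]:
    # Collect every weight shard with the given extension from the adapter directory.
    return sorted(p for p in Path(adapter_path).iterdir() if p.suffix == extension)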