feat: enable lora load from directory

Author: drbh
Date:   2024-07-09 02:38:50 +00:00
Parent: 70dc958fb8
Commit: 4b569341e6

7 changed files with 70 additions and 52 deletions

View File

@@ -289,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
             for rank_data in self.rank_data.values()
         )

-    @classmethod
-    def key(cls) -> str:
-        return "lora"
-
     @classmethod
     def load(
         self,

View File

@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
     def has_adapter(self, adapter_index: int) -> bool:
         pass

-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
-
     @abstractclassmethod
     def load(
         cls,
@@ -94,7 +90,7 @@ class LayerAdapterWeights:
                 adapter_weights, meta, prefill, prefill_head_indices
             )
             if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
         return batch_data
@@ -126,8 +122,7 @@ class AdapterBatchData:
     def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
            if lora_data is None:
                continue
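
For orientation, the hunks above remove the per-adapter-type keying: AdapterBatchData.data now maps a layer name directly to its batched LoRA weights instead of nesting them under a "lora" key. A minimal before/after sketch; the layer key "q_proj" is a hypothetical example, the other names come from this diff:

    # Before: the batched weights sat one level deeper, keyed by adapter type.
    #   lora_weights = adapter_data.data.get("q_proj", {}).get("lora")
    # After: only LoRA weights are batched, so each layer name maps to them directly.
    #   lora_weights = adapter_data.data.get("q_proj")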

View File

@@ -4,9 +4,10 @@ import typer
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters

 from text_generation_server.utils.log import log_master
@@ -80,22 +81,16 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
-
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
+    lora_adapters = parse_lora_adapters(os.environ.get("LORA_ADAPTERS", None))
+    if len(lora_adapters) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
         )

     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
+    if len(lora_adapters) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
         log_master(
             logger.warning,
             f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
@@ -117,7 +112,7 @@ def serve(
     )
     server.serve(
         model_id,
-        lora_adapter_ids,
+        lora_adapters,
         revision,
         sharded,
         quantize,
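
For reference, the LORA_ADAPTERS value that serve() now parses accepts either a bare adapter id or an id=path pair, comma-separated (see the parse_lora_adapters helper added in the last file of this diff). A short sketch with made-up adapter ids and paths:

    import os

    # Hypothetical values, for illustration only.
    os.environ["LORA_ADAPTERS"] = "org/adapter-a"                       # fetched by Hub id
    os.environ["LORA_ADAPTERS"] = "adapter-b=/data/adapters/adapter-b"  # loaded from a local directory
    os.environ["LORA_ADAPTERS"] = "org/adapter-a,adapter-b=/data/adapters/adapter-b"  # both, comma-separated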

View File

@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
     ) -> torch.Tensor:
         if adapter_data is None:
             return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.

View File

@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
 from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -38,6 +38,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
     load_and_merge_adapters,
+    AdapterInfo,
 )
 from text_generation_server.adapters.lora import LoraWeights
@@ -1125,7 +1126,7 @@ def _get_model(
 # this provides a post model loading hook to load adapters into the model after the model has been loaded
 def get_model(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -1133,8 +1134,9 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
-    adapter_to_index: dict[str, int],
+    adapter_to_index: Dict[str, int],
 ):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
     model = _get_model(
         model_id,
         lora_adapter_ids,
@@ -1147,14 +1149,14 @@
         max_input_tokens,
     )

-    if len(lora_adapter_ids) > 0:
+    if len(lora_adapters) > 0:
         target_to_layer = build_layer_weight_lookup(model.model)

-        for index, adapter_id in enumerate(lora_adapter_ids):
+        for index, adapter in enumerate(lora_adapters):
             # currenly we only load one adapter at a time but
             # this can be extended to merge multiple adapters
             adapter_parameters = AdapterParameters(
-                adapter_ids=[adapter_id],
+                adapter_info=[adapter],
                 weights=None,  # will be set to 1
                 merge_strategy=0,
                 density=1.0,
@@ -1162,13 +1164,13 @@
             )

             adapter_index = index + 1
-            adapter_to_index[adapter_id] = adapter_index
+            adapter_to_index[adapter.id] = adapter_index

             if adapter_index in model.loaded_adapters:
                 continue
             logger.info(
-                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
             )
             weight_names = tuple([v[0] for v in target_to_layer.values()])
             (
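
To make the indexing above concrete, a small sketch that mirrors the enumerate loop; the adapter ids and path are hypothetical, and AdapterInfo comes from this diff:

    from text_generation_server.utils.adapter import AdapterInfo

    # Two hypothetical adapters, as get_model would receive them from serve().
    lora_adapters = [
        AdapterInfo(id="adapter-a", path=None),
        AdapterInfo(id="adapter-b", path="/data/adapters/adapter-b"),
    ]
    adapter_to_index = {}
    for index, adapter in enumerate(lora_adapters):
        adapter_to_index[adapter.id] = index + 1  # adapter indices start at 1, mirroring the loop above

    assert adapter_to_index == {"adapter-a": 1, "adapter-b": 2}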

View File

@@ -9,11 +9,12 @@ from loguru import logger
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict

 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.utils.adapter import AdapterInfo

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -192,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -204,7 +205,7 @@
 ):
     async def serve_inner(
         model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -227,7 +228,7 @@
         try:
             model = get_model(
                 model_id,
-                lora_adapter_ids,
+                lora_adapters,
                 revision,
                 sharded,
                 quantize,
@@ -274,7 +275,7 @@
     asyncio.run(
         serve_inner(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
             revision,
             sharded,
             quantize,

View File

@@ -5,7 +5,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List

 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
@@ -23,9 +23,15 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"


+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
+
+
 @dataclass
 class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
     weights: Tuple[float]
     merge_strategy: NotImplemented
     density: float
@@ -39,6 +45,22 @@ class AdapterSource:
     revision: str


+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
+
+
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
@@ -46,10 +68,13 @@ def load_and_merge_adapters(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_parameters.adapter_ids[0],
+            adapter_info.id,
+            adapter_info.path,
             weight_names,
             trust_remote_code,
         )
@@ -79,14 +104,15 @@ def _load_and_merge(
     adapters_to_merge = []
     merged_weight_names = set()
     tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")

         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
-                adapter_id,
+                adapter.id,
+                adapter.path,
                 weight_names,
                 trust_remote_code,
             )
@@ -146,18 +172,24 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
+    adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
-    adapter_config = LoraConfig.load(adapter_id, None)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

-    adapter_filenames = hub._cached_adapter_weight_files(
-        adapter_id, revision=revision, extension=".safetensors"
-    )
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
+            adapter_id, revision=revision, extension=".safetensors"
+        )
+    )

     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
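
Finally, a usage sketch of the parse_lora_adapters helper defined in this file; the ids and path are hypothetical, and "=" is what separates an adapter id from a local adapter directory:

    from text_generation_server.utils.adapter import AdapterInfo, parse_lora_adapters

    assert parse_lora_adapters(None) == []
    assert parse_lora_adapters("adapter-a") == [AdapterInfo(id="adapter-a", path=None)]
    assert parse_lora_adapters("adapter-a, adapter-b=/data/adapters/adapter-b") == [
        AdapterInfo(id="adapter-a", path=None),
        AdapterInfo(id="adapter-b", path="/data/adapters/adapter-b"),
    ]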