feat: enable lora load from directory

parent 70dc958fb8
commit 4b569341e6
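
In short, the LORA_ADAPTERS environment variable can now map an adapter id to a local directory in addition to plain Hub ids. A minimal sketch of the accepted format, using hypothetical adapter ids and paths (the actual parsing lives in the parse_lora_adapters helper added further down):

    import os

    # Comma-separated entries; an optional "=<path>" loads that adapter from a local directory.
    os.environ["LORA_ADAPTERS"] = "myorg/hub-adapter,myorg/local-adapter=/data/adapters/local-adapter"
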
@@ -289,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
             for rank_data in self.rank_data.values()
         )
 
-    @classmethod
-    def key(cls) -> str:
-        return "lora"
-
     @classmethod
     def load(
         self,
@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
     def has_adapter(self, adapter_index: int) -> bool:
         pass
 
-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
-
     @abstractclassmethod
     def load(
         cls,
@@ -94,7 +90,7 @@ class LayerAdapterWeights:
                 adapter_weights, meta, prefill, prefill_head_indices
             )
             if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
         return batch_data
 
 
@@ -126,8 +122,7 @@ class AdapterBatchData:
     def ranks(self) -> Set[int]:
         # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
            if lora_data is None:
                continue
 
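
Taken together, the hunks above drop the intermediate "lora" key: batch_data and AdapterBatchData.data now map a layer directly to its BatchLoraWeights instead of to a nested dict. A minimal sketch of the resulting access pattern, assuming adapter_data is an AdapterBatchData built by the code above and "q_proj" is a hypothetical layer name:

    # before: lora_weights = adapter_data.data["q_proj"].get("lora")
    # after:  lora_weights = adapter_data.data.get("q_proj")
    lora_weights = adapter_data.data.get("q_proj")  # BatchLoraWeights or None
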
@@ -4,9 +4,10 @@ import typer
 
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters
 
 from text_generation_server.utils.log import log_master
 
@@ -80,22 +81,16 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
 
-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+    lora_adapters = parse_lora_adapters(os.environ.get("LORA_ADAPTERS", None))
 
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
+    if len(lora_adapters) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
         )
 
     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
+    if len(lora_adapters) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
         log_master(
             logger.warning,
             f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
@@ -117,7 +112,7 @@ def serve(
         )
         server.serve(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
             revision,
             sharded,
             quantize,
@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
     ) -> torch.Tensor:
         if adapter_data is None:
             return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.
@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
 from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -38,6 +38,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
     load_and_merge_adapters,
+    AdapterInfo,
 )
 from text_generation_server.adapters.lora import LoraWeights
 
@@ -1125,7 +1126,7 @@ def _get_model(
 # this provides a post model loading hook to load adapters into the model after the model has been loaded
 def get_model(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -1133,8 +1134,9 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
-    adapter_to_index: dict[str, int],
+    adapter_to_index: Dict[str, int],
 ):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
     model = _get_model(
         model_id,
         lora_adapter_ids,
@@ -1147,14 +1149,14 @@ def get_model(
         max_input_tokens,
     )
 
-    if len(lora_adapter_ids) > 0:
+    if len(lora_adapters) > 0:
         target_to_layer = build_layer_weight_lookup(model.model)
 
-        for index, adapter_id in enumerate(lora_adapter_ids):
+        for index, adapter in enumerate(lora_adapters):
             # currenly we only load one adapter at a time but
             # this can be extended to merge multiple adapters
             adapter_parameters = AdapterParameters(
-                adapter_ids=[adapter_id],
+                adapter_info=[adapter],
                 weights=None,  # will be set to 1
                 merge_strategy=0,
                 density=1.0,
@@ -1162,13 +1164,13 @@ def get_model(
             )
 
             adapter_index = index + 1
-            adapter_to_index[adapter_id] = adapter_index
+            adapter_to_index[adapter.id] = adapter_index
 
             if adapter_index in model.loaded_adapters:
                 continue
 
             logger.info(
-                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
             )
             weight_names = tuple([v[0] for v in target_to_layer.values()])
             (
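
A rough sketch of the per-adapter bookkeeping the loop above now performs, with a hypothetical adapter; fields and steps beyond those visible in this diff are omitted:

    from text_generation_server.utils.adapter import AdapterInfo

    lora_adapters = [AdapterInfo(id="myorg/local-adapter", path="/data/adapters/local-adapter")]
    adapter_to_index = {}
    for index, adapter in enumerate(lora_adapters):
        # ids now come from AdapterInfo; indices start at 1, as in the hunk above
        adapter_to_index[adapter.id] = index + 1
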
@@ -9,11 +9,12 @@ from loguru import logger
 
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict
 
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.utils.adapter import AdapterInfo
 
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -192,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
 def serve(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -204,7 +205,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -227,7 +228,7 @@ def serve(
         try:
             model = get_model(
                 model_id,
-                lora_adapter_ids,
+                lora_adapters,
                 revision,
                 sharded,
                 quantize,
@@ -274,7 +275,7 @@ def serve(
     asyncio.run(
         serve_inner(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
             revision,
             sharded,
             quantize,
@@ -5,7 +5,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List
 
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
@@ -23,9 +23,15 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
+
+
 @dataclass
 class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
     weights: Tuple[float]
     merge_strategy: NotImplemented
     density: float
@@ -39,6 +45,22 @@ class AdapterSource:
     revision: str
 
 
+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
+
+
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
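
A minimal usage sketch of the new helper (the adapter ids and the path are made up):

    from text_generation_server.utils.adapter import parse_lora_adapters

    adapters = parse_lora_adapters("myorg/hub-adapter,myorg/local-adapter=/data/adapters/local-adapter")
    # -> [AdapterInfo(id="myorg/hub-adapter", path=None),
    #     AdapterInfo(id="myorg/local-adapter", path="/data/adapters/local-adapter")]
    assert parse_lora_adapters(None) == []
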
@@ -46,10 +68,13 @@ def load_and_merge_adapters(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_parameters.adapter_ids[0],
+            adapter_info.id,
+            adapter_info.path,
             weight_names,
             trust_remote_code,
         )
@@ -79,14 +104,15 @@ def _load_and_merge(
     adapters_to_merge = []
     merged_weight_names = set()
     tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")
 
         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
-                adapter_id,
+                adapter.id,
+                adapter.path,
                 weight_names,
                 trust_remote_code,
             )
@@ -146,18 +172,24 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
+    adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
 
-    adapter_config = LoraConfig.load(adapter_id, None)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
 
-    adapter_filenames = hub._cached_adapter_weight_files(
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
         adapter_id, revision=revision, extension=".safetensors"
+        )
     )
 
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
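
With adapter_path set, the config and weights are read from the given directory rather than the Hub cache, and the base-model architecture check is skipped per the condition above. A hedged sketch of a call taking that branch; the model id, adapter id, path, and weight names are hypothetical:

    module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map(
        model_id="meta-llama/Llama-2-7b-hf",          # hypothetical base model
        adapter_id="myorg/local-adapter",             # hypothetical adapter id
        adapter_path="/data/adapters/local-adapter",  # local directory with the adapter config and .safetensors weights
        weight_names=("q_proj", "v_proj"),            # hypothetical target weight names
        trust_remote_code=False,
    )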