Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-12 04:44:52 +00:00)

feat: enable lora load from directory

parent 70dc958fb8
commit 4b569341e6
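The change extends the LORA_ADAPTERS environment variable: each comma-separated entry may now be either a plain adapter id (resolved from the Hugging Face Hub, as before) or an id=path pair whose path points at a local directory containing the adapter config and .safetensors weights. A minimal sketch of the two forms, with made-up adapter names and paths; the full diff follows:

    import os

    # Hypothetical values for illustration only.
    os.environ["LORA_ADAPTERS"] = ",".join(
        [
            "my-org/adapter-a",                   # plain id: fetched from the Hub as before
            "my-org/adapter-b=/data/adapters/b",  # id=path: loaded from a local directory
        ]
    )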
@@ -289,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
             for rank_data in self.rank_data.values()
         )
 
-    @classmethod
-    def key(cls) -> str:
-        return "lora"
-
     @classmethod
     def load(
         self,
@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
     def has_adapter(self, adapter_index: int) -> bool:
         pass
 
-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
-
     @abstractclassmethod
     def load(
         cls,
@@ -94,7 +90,7 @@ class LayerAdapterWeights:
                 adapter_weights, meta, prefill, prefill_head_indices
             )
             if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
         return batch_data
 
 
@@ -126,8 +122,7 @@ class AdapterBatchData:
     def ranks(self) -> Set[int]:
         # TODO(travis): refactor to be less coupled to lora implementation
         ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
             if lora_data is None:
                 continue
 
@@ -4,9 +4,10 @@ import typer
 
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters
 
 from text_generation_server.utils.log import log_master
 
@@ -80,22 +81,16 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
 
-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+    lora_adapters = parse_lora_adapters(os.environ.get("LORA_ADAPTERS", None))
 
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
+    if len(lora_adapters) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
         )
 
     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
+    if len(lora_adapters) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
         log_master(
             logger.warning,
             f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
@@ -117,7 +112,7 @@ def serve(
     )
     server.serve(
         model_id,
-        lora_adapter_ids,
+        lora_adapters,
         revision,
         sharded,
         quantize,
@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
     ) -> torch.Tensor:
         if adapter_data is None:
             return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.
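The LayerAdapterWeights, AdapterBatchData.ranks and LoraLinear hunks above all follow from one simplification: the per-layer adapter mapping no longer nests the batched weights under a "lora" key, so a lookup now returns the BatchLoraWeights object (or None) directly. A small stand-alone sketch of the before/after shape, using a hypothetical "q_proj" layer type and placeholder values:

    from typing import Any, Dict, Optional

    # Before: two levels, keyed by layer type and then by adapter kind.
    old_style: Dict[str, Dict[str, Any]] = {"q_proj": {"lora": "<BatchLoraWeights>"}}
    layer_data = old_style.get("q_proj")
    lora_data: Optional[Any] = layer_data.get("lora") if layer_data is not None else None

    # After: one level, keyed by layer type only.
    new_style: Dict[str, Any] = {"q_proj": "<BatchLoraWeights>"}
    lora_data = new_style.get("q_proj")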
@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
 from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -38,6 +38,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
     load_and_merge_adapters,
+    AdapterInfo,
 )
 from text_generation_server.adapters.lora import LoraWeights
 
@@ -1125,7 +1126,7 @@ def _get_model(
 # this provides a post model loading hook to load adapters into the model after the model has been loaded
 def get_model(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -1133,8 +1134,9 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
-    adapter_to_index: dict[str, int],
+    adapter_to_index: Dict[str, int],
 ):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
     model = _get_model(
         model_id,
         lora_adapter_ids,
@@ -1147,14 +1149,14 @@ def get_model(
         max_input_tokens,
     )
 
-    if len(lora_adapter_ids) > 0:
+    if len(lora_adapters) > 0:
         target_to_layer = build_layer_weight_lookup(model.model)
 
-        for index, adapter_id in enumerate(lora_adapter_ids):
+        for index, adapter in enumerate(lora_adapters):
             # currenly we only load one adapter at a time but
             # this can be extended to merge multiple adapters
             adapter_parameters = AdapterParameters(
-                adapter_ids=[adapter_id],
+                adapter_info=[adapter],
                 weights=None,  # will be set to 1
                 merge_strategy=0,
                 density=1.0,
@@ -1162,13 +1164,13 @@ def get_model(
             )
 
             adapter_index = index + 1
-            adapter_to_index[adapter_id] = adapter_index
+            adapter_to_index[adapter.id] = adapter_index
 
             if adapter_index in model.loaded_adapters:
                 continue
 
             logger.info(
-                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
             )
             weight_names = tuple([v[0] for v in target_to_layer.values()])
             (
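The loop above assigns adapter indices starting at 1 and keys the lookup table by adapter id (index 0 is presumably left for the base model). A minimal sketch, not TGI code, of the mapping produced for two hypothetical adapters:

    from text_generation_server.utils.adapter import AdapterInfo

    lora_adapters = [
        AdapterInfo(id="my-org/adapter-a", path=None),                # hypothetical Hub adapter
        AdapterInfo(id="my-org/adapter-b", path="/data/adapters/b"),  # hypothetical local directory
    ]

    adapter_to_index = {}
    for index, adapter in enumerate(lora_adapters):
        adapter_to_index[adapter.id] = index + 1

    assert adapter_to_index == {"my-org/adapter-a": 1, "my-org/adapter-b": 2}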
@@ -9,11 +9,12 @@ from loguru import logger
 
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict
 
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.utils.adapter import AdapterInfo
 
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -192,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
 def serve(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -204,7 +205,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -227,7 +228,7 @@ def serve(
         try:
             model = get_model(
                 model_id,
-                lora_adapter_ids,
+                lora_adapters,
                 revision,
                 sharded,
                 quantize,
@@ -274,7 +275,7 @@ def serve(
     asyncio.run(
         serve_inner(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
             revision,
             sharded,
             quantize,
@@ -5,7 +5,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List
 
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
@@ -23,9 +23,15 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
+
+
 @dataclass
 class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
     weights: Tuple[float]
     merge_strategy: NotImplemented
     density: float
@@ -39,6 +45,22 @@ class AdapterSource:
     revision: str
 
 
+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
+
+
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
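A quick illustration of what the new parser returns for the two supported entry forms (the adapter names and path are invented for the example); malformed entries such as "id=path=extra" raise a ValueError:

    from text_generation_server.utils.adapter import AdapterInfo, parse_lora_adapters

    adapters = parse_lora_adapters("my-org/adapter-a, my-org/adapter-b=/data/adapters/b")
    assert adapters == [
        AdapterInfo(id="my-org/adapter-a", path=None),
        AdapterInfo(id="my-org/adapter-b", path="/data/adapters/b"),
    ]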
@@ -46,10 +68,13 @@ def load_and_merge_adapters(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_parameters.adapter_ids[0],
+            adapter_info.id,
+            adapter_info.path,
             weight_names,
             trust_remote_code,
         )
@@ -79,14 +104,15 @@ def _load_and_merge(
     adapters_to_merge = []
     merged_weight_names = set()
     tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")
 
         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
-                adapter_id,
+                adapter.id,
+                adapter.path,
                 weight_names,
                 trust_remote_code,
             )
@@ -146,17 +172,23 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
+    adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
 
-    adapter_config = LoraConfig.load(adapter_id, None)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
 
-    adapter_filenames = hub._cached_adapter_weight_files(
-        adapter_id, revision=revision, extension=".safetensors"
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
+            adapter_id, revision=revision, extension=".safetensors"
+        )
     )
 
     try:
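Taken together, the last hunk makes weight resolution branch on whether a local path was supplied, and skips the base-model architecture check when it was (presumably because a locally fine-tuned adapter's config may not name the served model id). A simplified sketch of the branching, assuming the hub helpers referenced in the diff are importable from text_generation_server.utils.hub; it is not the exact implementation of those helpers:

    import os
    from typing import List, Optional

    from text_generation_server.utils import hub

    def resolve_adapter_weight_files(adapter_id: str, adapter_path: Optional[str]) -> List[str]:
        if adapter_path:
            # Local directory: pick up the .safetensors shards directly, roughly what
            # hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors") does.
            return [
                os.path.join(adapter_path, f)
                for f in os.listdir(adapter_path)
                if f.endswith(".safetensors")
            ]
        # No path: use the adapter weights already cached from the Hub, as before.
        return hub._cached_adapter_weight_files(
            adapter_id, revision="main", extension=".safetensors"
        )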