mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: refactor adapter weight loading and mapping
This commit is contained in:
parent
6aebf44f47
commit
70dc958fb8
@ -4,7 +4,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -31,14 +31,3 @@ class AdapterConfig(ABC):
|
||||
weight_names: Tuple[str],
|
||||
) -> Tuple[ModuleMap, Set[str]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_batched_adapter_weights(
|
||||
self,
|
||||
model: "Model",
|
||||
module_map: ModuleMap,
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
dynamic: bool,
|
||||
) -> Optional[AdapterWeights]:
|
||||
pass
|
||||
|
@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig):
|
||||
adapter_weight_names.add(lora_b_name)
|
||||
return module_map, adapter_weight_names
|
||||
|
||||
def load_batched_adapter_weights(
|
||||
self,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
dynamic: bool,
|
||||
) -> Optional[AdapterWeights]:
|
||||
return LoraWeights.load(
|
||||
self,
|
||||
model,
|
||||
module_map,
|
||||
layer_type,
|
||||
unused_weight_names,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
|
||||
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
|
||||
@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights):
|
||||
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
|
||||
return [BatchLoraWeights]
|
||||
|
||||
# prepare pre-loaded lora weights for use in the model.
|
||||
#
|
||||
# this method processes and organizes lora weights for a specific layer type across all layers:
|
||||
# - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
|
||||
# - retrieves weights from `module_map` based on the `layer_type`.
|
||||
# - processes `nlayers` number of layers.
|
||||
# - converts weights to the specified `dtype`.
|
||||
# - shards weights across `world_size` number of processes using the `process_group`.
|
||||
# - maps weights to specific layers using `target_to_layer`.
|
||||
# - tracks `unused_weight_names` to identify any unused weights.
|
||||
#
|
||||
# the method handles weight transposition, scaling, and padding to ensure compatibility
|
||||
# with SGMV or BGMV operations.
|
||||
@classmethod
|
||||
def load(
|
||||
def prepare_weights(
|
||||
cls,
|
||||
config: LoraConfig,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
nlayers: int,
|
||||
dtype: torch.dtype,
|
||||
world_size: int,
|
||||
process_group: ProcessGroup,
|
||||
target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
|
||||
) -> Optional[AdapterWeights]:
|
||||
nlayers = model.get_num_layers_for_type(layer_type)
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
weight_name, layer = model.target_to_layer[key]
|
||||
weight_name, layer = target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
||||
@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights):
|
||||
return None
|
||||
|
||||
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
|
||||
lora_a = lora_a.to(base_device, model.dtype)
|
||||
lora_a = lora_a.to(base_device, dtype)
|
||||
|
||||
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
|
||||
lora_b = lora_b.to(base_device, model.dtype)
|
||||
lora_b = lora_b.to(base_device, dtype)
|
||||
|
||||
scale = get_scaling_factor(
|
||||
config.lora_alpha,
|
||||
@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights):
|
||||
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||
|
||||
# pad lora ranks to be compatible with sgmv
|
||||
lora_a_list = [
|
||||
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
|
||||
]
|
||||
lora_b_list = [
|
||||
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
|
||||
]
|
||||
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
|
||||
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
|
||||
|
||||
if lora_a_list:
|
||||
# update rank if it was padded
|
||||
@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights):
|
||||
*shard_lora_weights(
|
||||
weights_a=lora_a_list,
|
||||
weights_b=lora_b_list,
|
||||
split_dim=0 if model.is_row_parallel(layer_type) else 1,
|
||||
process_group=model.process_group,
|
||||
split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
|
||||
process_group=process_group,
|
||||
),
|
||||
config,
|
||||
)
|
||||
|
@ -71,13 +71,6 @@ class LayerAdapterWeights:
|
||||
return
|
||||
del self.adapter_weights[adapter_idx]
|
||||
|
||||
@property
|
||||
def max_speculative_tokens(self) -> int:
|
||||
return max(
|
||||
adapter_weights.speculative_tokens
|
||||
for adapter_weights in self.adapter_weights.values()
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.adapter_weights) == 0
|
||||
|
||||
|
@ -33,6 +33,15 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
from text_generation_server.utils.adapter import (
|
||||
AdapterParameters,
|
||||
build_layer_weight_lookup,
|
||||
load_and_merge_adapters,
|
||||
)
|
||||
from text_generation_server.adapters.lora import LoraWeights
|
||||
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
@ -294,7 +303,7 @@ for data in ModelType:
|
||||
__GLOBALS[data.name] = data.value["type"]
|
||||
|
||||
|
||||
def get_model(
|
||||
def _get_model(
|
||||
model_id: str,
|
||||
lora_adapter_ids: Optional[List[str]],
|
||||
revision: Optional[str],
|
||||
@ -1110,3 +1119,114 @@ def get_model(
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported model type {model_type}")
|
||||
|
||||
|
||||
# get_model wraps the internal _get_model function and adds support for loading adapters
|
||||
# this provides a post model loading hook to load adapters into the model after the model has been loaded
|
||||
def get_model(
|
||||
model_id: str,
|
||||
lora_adapter_ids: Optional[List[str]],
|
||||
revision: Optional[str],
|
||||
sharded: bool,
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
adapter_to_index: dict[str, int],
|
||||
):
|
||||
model = _get_model(
|
||||
model_id,
|
||||
lora_adapter_ids,
|
||||
revision,
|
||||
sharded,
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
||||
if len(lora_adapter_ids) > 0:
|
||||
target_to_layer = build_layer_weight_lookup(model.model)
|
||||
|
||||
for index, adapter_id in enumerate(lora_adapter_ids):
|
||||
# currenly we only load one adapter at a time but
|
||||
# this can be extended to merge multiple adapters
|
||||
adapter_parameters = AdapterParameters(
|
||||
adapter_ids=[adapter_id],
|
||||
weights=None, # will be set to 1
|
||||
merge_strategy=0,
|
||||
density=1.0,
|
||||
majority_sign_method=0,
|
||||
)
|
||||
|
||||
adapter_index = index + 1
|
||||
adapter_to_index[adapter_id] = adapter_index
|
||||
|
||||
if adapter_index in model.loaded_adapters:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
|
||||
)
|
||||
weight_names = tuple([v[0] for v in target_to_layer.values()])
|
||||
(
|
||||
module_map,
|
||||
adapter_config,
|
||||
adapter_weight_names,
|
||||
adapter_tokenizer,
|
||||
) = load_and_merge_adapters(
|
||||
model.model_id,
|
||||
adapter_parameters,
|
||||
adapter_index,
|
||||
weight_names,
|
||||
False,
|
||||
)
|
||||
|
||||
unused_weight_names = adapter_weight_names.copy()
|
||||
|
||||
adapter_layers = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
|
||||
for layer_name in adapter_layers:
|
||||
nlayers = (
|
||||
1 if layer_name == "lm_head" else len(model.model.model.layers)
|
||||
)
|
||||
adapter_weights = LoraWeights.prepare_weights(
|
||||
config=adapter_config,
|
||||
module_map=module_map,
|
||||
layer_type=layer_name,
|
||||
unused_weight_names=unused_weight_names,
|
||||
nlayers=nlayers,
|
||||
dtype=model.dtype,
|
||||
world_size=model.world_size,
|
||||
process_group=model.process_group,
|
||||
target_to_layer=target_to_layer,
|
||||
)
|
||||
|
||||
if adapter_weights is None:
|
||||
continue
|
||||
|
||||
model.layer_to_adapter_weights[layer_name].add_adapter(
|
||||
adapter_index, adapter_weights
|
||||
)
|
||||
|
||||
if len(unused_weight_names) > 0:
|
||||
logger.warning(
|
||||
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
|
||||
)
|
||||
|
||||
if adapter_tokenizer is not None:
|
||||
model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
|
||||
|
||||
model.loaded_adapters.add(adapter_index)
|
||||
|
||||
return model
|
||||
|
@ -1,7 +1,6 @@
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import itertools
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
@ -1695,72 +1694,3 @@ class FlashCausalLM(Model):
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
layer_weights = {}
|
||||
|
||||
prefix = "model.layers"
|
||||
|
||||
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
|
||||
# that have a language_model inside of the larger model.
|
||||
if hasattr(self.model, "language_model"):
|
||||
_model = self.model.language_model
|
||||
elif hasattr(self.model, "text_model"):
|
||||
_model = self.model.text_model
|
||||
else:
|
||||
_model = self.model
|
||||
|
||||
for i, layer in enumerate(_model.model.layers):
|
||||
layer_weights[(i, "q_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "k_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "v_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "o_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
# TODO: this is a hack to avoid the gate_proj for
|
||||
# FlashStarcoder2 that doesnt have these layers
|
||||
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "up_proj")] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "down_proj")] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return ADAPTER_LAYERS
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return ["q_proj", "v_proj"]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
||||
|
@ -1,85 +0,0 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
|
||||
|
||||
ADAPTER_LAYERS = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
|
||||
|
||||
|
||||
class FlashMistral(FlashCausalLM):
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
layer_weights = {}
|
||||
|
||||
prefix = "model.layers"
|
||||
|
||||
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
|
||||
# that have a language_model inside of the larger model.
|
||||
if hasattr(self.model, "text_model"):
|
||||
_model = self.model.text_model
|
||||
else:
|
||||
_model = self.model
|
||||
|
||||
for i, layer in enumerate(_model.model.layers):
|
||||
layer_weights[(i, "q_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "k_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "v_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "o_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
# TODO: this is a hack to avoid the gate_proj for
|
||||
# FlashStarcoder2 that doesnt have these layers
|
||||
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "up_proj")] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "down_proj")] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return ADAPTER_LAYERS
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return ["q_proj", "v_proj"]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
@ -4,20 +4,12 @@ import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
|
||||
from collections import defaultdict
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from text_generation_server.models.types import Batch, Generation
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
from text_generation_server.adapters.weights import LayerAdapterWeights
|
||||
from text_generation_server.utils.adapter import (
|
||||
load_and_merge_adapters,
|
||||
AdapterParameters,
|
||||
AdapterSource,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_MODEL_ADAPTER_ID = "__base_model__"
|
||||
|
||||
@ -61,7 +53,6 @@ class Model(ABC):
|
||||
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
|
||||
LayerAdapterWeights
|
||||
)
|
||||
self.target_to_layer = None
|
||||
self.loaded_adapters = set()
|
||||
self.static_adapter_id = adapter_id
|
||||
|
||||
@ -142,140 +133,3 @@ class Model(ABC):
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return False
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 0
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def max_speculative_tokens(self) -> int:
|
||||
return max(
|
||||
[
|
||||
weights.max_speculative_tokens
|
||||
for weights in self.layer_to_adapter_weights.values()
|
||||
],
|
||||
default=0,
|
||||
)
|
||||
|
||||
def load_adapter(
|
||||
self,
|
||||
adapter_parameters: AdapterParameters,
|
||||
adapter_source: AdapterSource,
|
||||
adapter_index: int,
|
||||
api_token: str,
|
||||
dynamic: bool = True,
|
||||
):
|
||||
"""Loads adapter weights from disk / host memory on the GPU.
|
||||
|
||||
adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
|
||||
into model. Otherwise, the adapter weights are applied during the forward
|
||||
pass and stored separately from the base model parameters.
|
||||
"""
|
||||
if self.target_to_layer is None:
|
||||
self.target_to_layer = self.adapter_target_to_layer()
|
||||
if adapter_index in self.loaded_adapters:
|
||||
# Adapter already loaded
|
||||
return
|
||||
|
||||
if not self.supports_adapter_loading:
|
||||
raise ValueError("This model does not support adapter loading.")
|
||||
|
||||
if dynamic and not self.dynamic_adapter_loading_enabled:
|
||||
raise ValueError(
|
||||
f"This model was initialized with the adapter {self.static_adapter_id} "
|
||||
f"and therefore does not support dynamic adapter loading. "
|
||||
f"Please initialize a new model instance from the base model in "
|
||||
f"order to use the dynamic adapter loading feature."
|
||||
)
|
||||
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
|
||||
)
|
||||
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
|
||||
(
|
||||
module_map,
|
||||
adapter_config,
|
||||
adapter_weight_names,
|
||||
adapter_tokenizer,
|
||||
) = load_and_merge_adapters(
|
||||
self.model_id,
|
||||
adapter_parameters,
|
||||
adapter_source,
|
||||
adapter_index,
|
||||
weight_names,
|
||||
api_token,
|
||||
False,
|
||||
)
|
||||
|
||||
unused_weight_names = adapter_weight_names.copy()
|
||||
for layer_name in self.adapter_layers:
|
||||
adapter_weights = adapter_config.load_batched_adapter_weights(
|
||||
self,
|
||||
module_map,
|
||||
layer_name,
|
||||
unused_weight_names,
|
||||
dynamic,
|
||||
)
|
||||
|
||||
if adapter_weights is None:
|
||||
continue
|
||||
|
||||
layer_weights = self.layer_to_adapter_weights[layer_name]
|
||||
layer_weights.add_adapter(adapter_index, adapter_weights)
|
||||
|
||||
if len(unused_weight_names) > 0:
|
||||
log_master(
|
||||
logger.warning,
|
||||
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
|
||||
)
|
||||
|
||||
if adapter_tokenizer is not None:
|
||||
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
|
||||
|
||||
self.loaded_adapters.add(adapter_index)
|
||||
|
||||
def offload_adapter(
|
||||
self,
|
||||
adapter_parameters: AdapterParameters,
|
||||
adapter_source: AdapterSource,
|
||||
adapter_index: int,
|
||||
):
|
||||
"""Offloads the adapter weights from GPU to CPU or disk."""
|
||||
if adapter_index not in self.loaded_adapters:
|
||||
# Adapter already offloaded
|
||||
return
|
||||
|
||||
if not self.supports_adapter_loading:
|
||||
raise ValueError("This model does not support adapter loading.")
|
||||
|
||||
if not self.dynamic_adapter_loading_enabled:
|
||||
raise ValueError(
|
||||
f"This model was initialized with the adapter {self.static_adapter_id} "
|
||||
f"and therefore does not support dynamic adapter loading. "
|
||||
f"Please initialize a new model instance from the base model in "
|
||||
f"order to use the dynamic adapter loading feature."
|
||||
)
|
||||
|
||||
for layer_name in self.adapter_layers:
|
||||
if layer_name in self.layer_to_adapter_weights:
|
||||
self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
|
||||
|
||||
self.loaded_adapters.remove(adapter_index)
|
||||
|
@ -30,9 +30,6 @@ except (ImportError, NotImplementedError):
|
||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
|
||||
from text_generation_server.utils.adapter import (
|
||||
AdapterParameters,
|
||||
)
|
||||
|
||||
|
||||
class SignalHandler:
|
||||
@ -238,29 +235,9 @@ def serve(
|
||||
dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
adapter_to_index,
|
||||
)
|
||||
|
||||
if len(lora_adapter_ids) > 0:
|
||||
for index, adapter_id in enumerate(lora_adapter_ids):
|
||||
# TODO: improve non merged adapter loading and long term
|
||||
# improve adapter loading as a whole
|
||||
adapter_parameters = AdapterParameters(
|
||||
adapter_ids=[adapter_id],
|
||||
weights=None, # will be set to 1
|
||||
merge_strategy=0,
|
||||
density=1.0,
|
||||
majority_sign_method=0,
|
||||
)
|
||||
adapter_index = index + 1
|
||||
adapter_to_index[adapter_id] = adapter_index
|
||||
model.load_adapter(
|
||||
adapter_parameters,
|
||||
None, # adapter_source
|
||||
adapter_index,
|
||||
None, # api_token
|
||||
False, # dynamic
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Set, Tuple
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils.merges.strategies import merge_adapters
|
||||
|
||||
from text_generation_server.utils import hub
|
||||
@ -43,34 +42,25 @@ class AdapterSource:
|
||||
def load_and_merge_adapters(
|
||||
model_id: str,
|
||||
adapter_parameters: AdapterParameters,
|
||||
adapter_source: str,
|
||||
adapter_index: int,
|
||||
weight_names: Tuple[str],
|
||||
api_token: str,
|
||||
trust_remote_code: bool = False,
|
||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||
if len(adapter_parameters.adapter_ids) == 1:
|
||||
return load_module_map(
|
||||
model_id,
|
||||
adapter_parameters.adapter_ids[0],
|
||||
adapter_source,
|
||||
weight_names,
|
||||
api_token,
|
||||
trust_remote_code,
|
||||
)
|
||||
|
||||
adapter_params = AdapterParametersContainer(
|
||||
adapter_parameters, adapter_source, adapter_index
|
||||
)
|
||||
return _load_and_merge(
|
||||
model_id, adapter_params, weight_names, api_token, trust_remote_code
|
||||
)
|
||||
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
|
||||
return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterParametersContainer:
|
||||
adapter_parameters: AdapterParameters
|
||||
adapter_source: str
|
||||
adapter_index: int
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -82,7 +72,6 @@ def _load_and_merge(
|
||||
model_id: str,
|
||||
adapter_params: AdapterParametersContainer,
|
||||
weight_names: Tuple[str],
|
||||
api_token: str,
|
||||
trust_remote_code: bool = False,
|
||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||
params = adapter_params.adapter_parameters
|
||||
@ -98,9 +87,7 @@ def _load_and_merge(
|
||||
load_module_map(
|
||||
model_id,
|
||||
adapter_id,
|
||||
adapter_params.adapter_source,
|
||||
weight_names,
|
||||
api_token,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
@ -159,14 +146,12 @@ def check_architectures(
|
||||
def load_module_map(
|
||||
model_id: str,
|
||||
adapter_id: str,
|
||||
adapter_source: str,
|
||||
weight_names: Tuple[str],
|
||||
api_token: str,
|
||||
trust_remote_code: bool = False,
|
||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||
revision = "main"
|
||||
|
||||
adapter_config = LoraConfig.load(adapter_id, api_token)
|
||||
adapter_config = LoraConfig.load(adapter_id, None)
|
||||
if adapter_config.base_model_name_or_path != model_id:
|
||||
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
|
||||
|
||||
@ -177,7 +162,6 @@ def load_module_map(
|
||||
try:
|
||||
adapter_tokenizer = AutoTokenizer.from_pretrained(
|
||||
adapter_config.config_path,
|
||||
token=api_token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
except Exception:
|
||||
@ -194,3 +178,70 @@ def load_module_map(
|
||||
adapter_weights, weight_names
|
||||
)
|
||||
return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
|
||||
|
||||
|
||||
def get_attn_weights(i, layer):
|
||||
qkv = layer.self_attn.query_key_value
|
||||
weights = {}
|
||||
|
||||
for k in ["q", "k", "v"]:
|
||||
key = (i, f"{k}_proj")
|
||||
value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
|
||||
weights[key] = value
|
||||
|
||||
weights[(i, "o_proj")] = (
|
||||
f"model.layers.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def get_mlp_weights(i, layer):
|
||||
weights = {}
|
||||
|
||||
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
|
||||
gup = layer.mlp.gate_up_proj
|
||||
|
||||
for k in ["gate", "up", "down"]:
|
||||
key = (i, f"{k}_proj")
|
||||
value = (
|
||||
f"model.layers.{i}.mlp.{k}_proj",
|
||||
gup if k != "down" else layer.mlp.down_proj,
|
||||
)
|
||||
weights[key] = value
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
# build_layer_weight_lookup creates a mapping of model layers to their corresponding
|
||||
# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
|
||||
# containing the weight tensor path and the actual layer object. This mapping is needed
|
||||
# for the lora adapter to know which weights to update when applying the adapter.
|
||||
def build_layer_weight_lookup(model):
|
||||
if hasattr(model, "language_model"):
|
||||
m = model.language_model.model
|
||||
elif hasattr(model, "text_model"):
|
||||
m = model.text_model.model
|
||||
else:
|
||||
m = model.model
|
||||
|
||||
layer_weights = {}
|
||||
|
||||
for i, layer in enumerate(m.layers):
|
||||
attn_weights = get_attn_weights(i, layer)
|
||||
mlp_weights = get_mlp_weights(i, layer)
|
||||
|
||||
layer_weights.update(attn_weights)
|
||||
layer_weights.update(mlp_weights)
|
||||
|
||||
lm_head = None
|
||||
if hasattr(m, "lm_head"):
|
||||
lm_head = m.lm_head
|
||||
elif hasattr(model, "lm_head"):
|
||||
lm_head = model.lm_head
|
||||
|
||||
if lm_head:
|
||||
layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
|
||||
|
||||
return layer_weights
|
||||
|
Loading…
Reference in New Issue
Block a user