Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-12 04:44:52 +00:00

fix: refactor adapter weight loading and mapping

This commit is contained in:
parent 6aebf44f47
commit 70dc958fb8
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Set, Tuple
 
 import torch
 
@@ -31,14 +31,3 @@ class AdapterConfig(ABC):
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
-
-    @abstractmethod
-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: ModuleMap,
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        pass
@@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig):
             adapter_weight_names.add(lora_b_name)
         return module_map, adapter_weight_names
 
-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: Dict[str, Dict],
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        return LoraWeights.load(
-            self,
-            model,
-            module_map,
-            layer_type,
-            unused_weight_names,
-        )
-
     @classmethod
     def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
         hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
@@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights):
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
         return [BatchLoraWeights]
 
+    # prepare pre-loaded lora weights for use in the model.
+    #
+    # this method processes and organizes lora weights for a specific layer type across all layers:
+    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
+    # - retrieves weights from `module_map` based on the `layer_type`.
+    # - processes `nlayers` number of layers.
+    # - converts weights to the specified `dtype`.
+    # - shards weights across `world_size` number of processes using the `process_group`.
+    # - maps weights to specific layers using `target_to_layer`.
+    # - tracks `unused_weight_names` to identify any unused weights.
+    #
+    # the method handles weight transposition, scaling, and padding to ensure compatibility
+    # with SGMV or BGMV operations.
     @classmethod
-    def load(
+    def prepare_weights(
         cls,
         config: LoraConfig,
-        model: "Model",
         module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
+        nlayers: int,
+        dtype: torch.dtype,
+        world_size: int,
+        process_group: ProcessGroup,
+        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
     ) -> Optional[AdapterWeights]:
-        nlayers = model.get_num_layers_for_type(layer_type)
         lora_a_list = [None] * nlayers
         lora_b_list = [None] * nlayers
 
         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
-            weight_name, layer = model.target_to_layer[key]
+            weight_name, layer = target_to_layer[key]
             base_weight = layer.base_layer.linear.weight
             base_device = base_weight.device
 
@@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights):
                 return None
 
             lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-            lora_a = lora_a.to(base_device, model.dtype)
+            lora_a = lora_a.to(base_device, dtype)
 
             lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-            lora_b = lora_b.to(base_device, model.dtype)
+            lora_b = lora_b.to(base_device, dtype)
 
             scale = get_scaling_factor(
                 config.lora_alpha,
@@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights):
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
-        ]
+        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
+        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
 
         if lora_a_list:
             # update rank if it was padded
@@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights):
             *shard_lora_weights(
                 weights_a=lora_a_list,
                 weights_b=lora_b_list,
-                split_dim=0 if model.is_row_parallel(layer_type) else 1,
-                process_group=model.process_group,
+                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
+                process_group=process_group,
             ),
             config,
         )
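The hunks above turn `LoraWeights.load` into a `prepare_weights` classmethod that no longer receives the whole `Model` object; the layer count, dtype, sharding information, and layer lookup are all passed in explicitly. A minimal sketch of a call site under that new signature follows. The surrounding objects (`model`, `adapter_config`, `module_map`, `unused_weight_names`) are assumed to already exist, as they do in the new `get_model` wrapper later in this diff; this is illustrative, not part of the commit.

# Sketch only: drives the new classmethod without handing it the Model object.
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.adapter import build_layer_weight_lookup

# (layer_id, layer_type) -> (weight path, layer module)
target_to_layer = build_layer_weight_lookup(model.model)

adapter_weights = LoraWeights.prepare_weights(
    config=adapter_config,                    # LoraConfig for the adapter being loaded
    module_map=module_map,                    # adapter tensors keyed by target weight name
    layer_type="q_proj",                      # one layer type at a time
    unused_weight_names=unused_weight_names,  # set, drained as weights are consumed
    nlayers=len(model.model.model.layers),    # 1 when layer_type == "lm_head"
    dtype=model.dtype,
    world_size=model.world_size,
    process_group=model.process_group,
    target_to_layer=target_to_layer,
)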
@@ -71,13 +71,6 @@ class LayerAdapterWeights:
             return
         del self.adapter_weights[adapter_idx]
 
-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            adapter_weights.speculative_tokens
-            for adapter_weights in self.adapter_weights.values()
-        )
-
     def is_empty(self) -> bool:
         return len(self.adapter_weights) == 0
 
@@ -33,6 +33,15 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
     T5ForConditionalGeneration,
 )
 
+from text_generation_server.utils.adapter import (
+    AdapterParameters,
+    build_layer_weight_lookup,
+    load_and_merge_adapters,
+)
+from text_generation_server.adapters.lora import LoraWeights
+
+
+
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_master
 
@@ -294,7 +303,7 @@ for data in ModelType:
     __GLOBALS[data.name] = data.value["type"]
 
 
-def get_model(
+def _get_model(
     model_id: str,
     lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
@@ -1110,3 +1119,114 @@ def get_model(
         )
 
     raise ValueError(f"Unsupported model type {model_type}")
+
+
+# get_model wraps the internal _get_model function and adds support for loading adapters
+# this provides a post model loading hook to load adapters into the model after the model has been loaded
+def get_model(
+    model_id: str,
+    lora_adapter_ids: Optional[List[str]],
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    speculate: Optional[int],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    max_input_tokens: int,
+    adapter_to_index: dict[str, int],
+):
+    model = _get_model(
+        model_id,
+        lora_adapter_ids,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        max_input_tokens,
+    )
+
+    if len(lora_adapter_ids) > 0:
+        target_to_layer = build_layer_weight_lookup(model.model)
+
+        for index, adapter_id in enumerate(lora_adapter_ids):
+            # currenly we only load one adapter at a time but
+            # this can be extended to merge multiple adapters
+            adapter_parameters = AdapterParameters(
+                adapter_ids=[adapter_id],
+                weights=None,  # will be set to 1
+                merge_strategy=0,
+                density=1.0,
+                majority_sign_method=0,
+            )
+
+            adapter_index = index + 1
+            adapter_to_index[adapter_id] = adapter_index
+
+            if adapter_index in model.loaded_adapters:
+                continue
+
+            logger.info(
+                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+            )
+            weight_names = tuple([v[0] for v in target_to_layer.values()])
+            (
+                module_map,
+                adapter_config,
+                adapter_weight_names,
+                adapter_tokenizer,
+            ) = load_and_merge_adapters(
+                model.model_id,
+                adapter_parameters,
+                adapter_index,
+                weight_names,
+                False,
+            )
+
+            unused_weight_names = adapter_weight_names.copy()
+
+            adapter_layers = [
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ]
+
+            for layer_name in adapter_layers:
+                nlayers = (
+                    1 if layer_name == "lm_head" else len(model.model.model.layers)
+                )
+                adapter_weights = LoraWeights.prepare_weights(
+                    config=adapter_config,
+                    module_map=module_map,
+                    layer_type=layer_name,
+                    unused_weight_names=unused_weight_names,
+                    nlayers=nlayers,
+                    dtype=model.dtype,
+                    world_size=model.world_size,
+                    process_group=model.process_group,
+                    target_to_layer=target_to_layer,
+                )
+
+                if adapter_weights is None:
+                    continue
+
+                model.layer_to_adapter_weights[layer_name].add_adapter(
+                    adapter_index, adapter_weights
+                )
+
+            if len(unused_weight_names) > 0:
+                logger.warning(
+                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                )
+
+            if adapter_tokenizer is not None:
+                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+            model.loaded_adapters.add(adapter_index)
+
+    return model
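With this wrapper, the adapter loop lives entirely in `get_model`; a caller only supplies the list of adapter ids and a dict to be filled with their indices. A rough sketch of the call from the server side, with placeholder argument values, mirroring the `serve` change further down in this diff:

# Sketch only: how a caller hands adapter bookkeeping to the new wrapper.
adapter_to_index: dict[str, int] = {}

model = get_model(
    model_id,
    lora_adapter_ids,   # e.g. ["some-org/some-lora"] (placeholder id)
    revision,
    sharded,
    quantize,
    speculate,
    dtype,
    trust_remote_code,
    max_input_tokens,
    adapter_to_index,   # filled in: adapter id -> 1-based adapter index
)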
@@ -1,7 +1,6 @@
 import math
 import os
 import time
-import itertools
 import torch
 import torch.distributed
 
@@ -1695,72 +1694,3 @@ class FlashCausalLM(Model):
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL
@@ -1,85 +0,0 @@
-import torch
-from typing import Optional, Tuple, Dict, List
-
-from text_generation_server.models import FlashCausalLM
-
-
-ADAPTER_LAYERS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
-ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
-
-
-class FlashMistral(FlashCausalLM):
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL
@@ -4,20 +4,12 @@ import torch
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
 from collections import defaultdict
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 
 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
-from text_generation_server.utils.adapter import (
-    load_and_merge_adapters,
-    AdapterParameters,
-    AdapterSource,
-)
-from text_generation_server.utils.log import log_master
-from loguru import logger
-
 
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
@@ -61,7 +53,6 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
 
@@ -142,140 +133,3 @@ class Model(ABC):
             raise RuntimeError(
                 f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
             )
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return False
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        return {}
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return []
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return []
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 0
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return False
-
-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            [
-                weights.max_speculative_tokens
-                for weights in self.layer_to_adapter_weights.values()
-            ],
-            default=0,
-        )
-
-    def load_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-        api_token: str,
-        dynamic: bool = True,
-    ):
-        """Loads adapter weights from disk / host memory on the GPU.
-
-        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
-        into model. Otherwise, the adapter weights are applied during the forward
-        pass and stored separately from the base model parameters.
-        """
-        if self.target_to_layer is None:
-            self.target_to_layer = self.adapter_target_to_layer()
-        if adapter_index in self.loaded_adapters:
-            # Adapter already loaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if dynamic and not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        log_master(
-            logger.info,
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
-        )
-        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
-        (
-            module_map,
-            adapter_config,
-            adapter_weight_names,
-            adapter_tokenizer,
-        ) = load_and_merge_adapters(
-            self.model_id,
-            adapter_parameters,
-            adapter_source,
-            adapter_index,
-            weight_names,
-            api_token,
-            False,
-        )
-
-        unused_weight_names = adapter_weight_names.copy()
-        for layer_name in self.adapter_layers:
-            adapter_weights = adapter_config.load_batched_adapter_weights(
-                self,
-                module_map,
-                layer_name,
-                unused_weight_names,
-                dynamic,
-            )
-
-            if adapter_weights is None:
-                continue
-
-            layer_weights = self.layer_to_adapter_weights[layer_name]
-            layer_weights.add_adapter(adapter_index, adapter_weights)
-
-        if len(unused_weight_names) > 0:
-            log_master(
-                logger.warning,
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
-            )
-
-        if adapter_tokenizer is not None:
-            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
-
-        self.loaded_adapters.add(adapter_index)
-
-    def offload_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-    ):
-        """Offloads the adapter weights from GPU to CPU or disk."""
-        if adapter_index not in self.loaded_adapters:
-            # Adapter already offloaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        for layer_name in self.adapter_layers:
-            if layer_name in self.layer_to_adapter_weights:
-                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
-
-        self.loaded_adapters.remove(adapter_index)
@@ -30,9 +30,6 @@ except (ImportError, NotImplementedError):
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id, set_adapter_to_index
-from text_generation_server.utils.adapter import (
-    AdapterParameters,
-)
 
 
 class SignalHandler:
@@ -238,27 +235,7 @@ def serve(
                 dtype,
                 trust_remote_code,
                 max_input_tokens,
+                adapter_to_index,
             )
-
-            if len(lora_adapter_ids) > 0:
-                for index, adapter_id in enumerate(lora_adapter_ids):
-                    # TODO: improve non merged adapter loading and long term
-                    # improve adapter loading as a whole
-                    adapter_parameters = AdapterParameters(
-                        adapter_ids=[adapter_id],
-                        weights=None,  # will be set to 1
-                        merge_strategy=0,
-                        density=1.0,
-                        majority_sign_method=0,
-                    )
-                    adapter_index = index + 1
-                    adapter_to_index[adapter_id] = adapter_index
-                    model.load_adapter(
-                        adapter_parameters,
-                        None,  # adapter_source
-                        adapter_index,
-                        None,  # api_token
-                        False,  # dynamic
-                    )
 
         except Exception:
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Set, Tuple
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
-from text_generation_server.pb import generate_pb2
 from text_generation_server.utils.merges.strategies import merge_adapters
 
 from text_generation_server.utils import hub
@@ -43,34 +42,25 @@ class AdapterSource:
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
-    adapter_source: str,
     adapter_index: int,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     if len(adapter_parameters.adapter_ids) == 1:
         return load_module_map(
             model_id,
             adapter_parameters.adapter_ids[0],
-            adapter_source,
             weight_names,
-            api_token,
             trust_remote_code,
         )
 
-    adapter_params = AdapterParametersContainer(
-        adapter_parameters, adapter_source, adapter_index
-    )
-    return _load_and_merge(
-        model_id, adapter_params, weight_names, api_token, trust_remote_code
-    )
+    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
+    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
 
 
 @dataclass
 class AdapterParametersContainer:
     adapter_parameters: AdapterParameters
-    adapter_source: str
     adapter_index: int
 
     def __hash__(self) -> int:
@@ -82,7 +72,6 @@ def _load_and_merge(
     model_id: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     params = adapter_params.adapter_parameters
@@ -98,9 +87,7 @@ def _load_and_merge(
             load_module_map(
                 model_id,
                 adapter_id,
-                adapter_params.adapter_source,
                 weight_names,
-                api_token,
                 trust_remote_code,
             )
         )
@@ -159,14 +146,12 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
-    adapter_source: str,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
 
-    adapter_config = LoraConfig.load(adapter_id, api_token)
+    adapter_config = LoraConfig.load(adapter_id, None)
     if adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
 
@@ -177,7 +162,6 @@ def load_module_map(
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
-            token=api_token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
@@ -194,3 +178,70 @@ def load_module_map(
         adapter_weights, weight_names
     )
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def get_attn_weights(i, layer):
+    qkv = layer.self_attn.query_key_value
+    weights = {}
+
+    for k in ["q", "k", "v"]:
+        key = (i, f"{k}_proj")
+        value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
+        weights[key] = value
+
+    weights[(i, "o_proj")] = (
+        f"model.layers.{i}.self_attn.o_proj",
+        layer.self_attn.o_proj,
+    )
+
+    return weights
+
+
+def get_mlp_weights(i, layer):
+    weights = {}
+
+    if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
+        gup = layer.mlp.gate_up_proj
+
+        for k in ["gate", "up", "down"]:
+            key = (i, f"{k}_proj")
+            value = (
+                f"model.layers.{i}.mlp.{k}_proj",
+                gup if k != "down" else layer.mlp.down_proj,
+            )
+            weights[key] = value
+
+    return weights
+
+
+# build_layer_weight_lookup creates a mapping of model layers to their corresponding
+# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
+# containing the weight tensor path and the actual layer object. This mapping is needed
+# for the lora adapter to know which weights to update when applying the adapter.
+def build_layer_weight_lookup(model):
+    if hasattr(model, "language_model"):
+        m = model.language_model.model
+    elif hasattr(model, "text_model"):
+        m = model.text_model.model
+    else:
+        m = model.model
+
+    layer_weights = {}
+
+    for i, layer in enumerate(m.layers):
+        attn_weights = get_attn_weights(i, layer)
+        mlp_weights = get_mlp_weights(i, layer)
+
+        layer_weights.update(attn_weights)
+        layer_weights.update(mlp_weights)
+
+    lm_head = None
+    if hasattr(m, "lm_head"):
+        lm_head = m.lm_head
+    elif hasattr(model, "lm_head"):
+        lm_head = model.lm_head
+
+    if lm_head:
+        layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
+
+    return layer_weights
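build_layer_weight_lookup keys its dictionary by (layer_id, layer_type) and stores the weight's dotted path together with the live layer module. For a hypothetical two-layer Llama-style model the result would look roughly like the sketch below; the right-hand values are placeholders standing in for the actual layer modules, not names from this commit.

# Illustrative shape only; real values are the model's layer modules.
lookup = {
    (0, "q_proj"): ("model.layers.0.self_attn.q_proj", qkv_0),
    (0, "k_proj"): ("model.layers.0.self_attn.k_proj", qkv_0),
    (0, "v_proj"): ("model.layers.0.self_attn.v_proj", qkv_0),
    (0, "o_proj"): ("model.layers.0.self_attn.o_proj", o_proj_0),
    (0, "gate_proj"): ("model.layers.0.mlp.gate_proj", gate_up_0),
    (0, "up_proj"): ("model.layers.0.mlp.up_proj", gate_up_0),
    (0, "down_proj"): ("model.layers.0.mlp.down_proj", down_proj_0),
    # ... the same seven entries for layer 1 ...
    (0, "lm_head"): ("lm_head", lm_head),
}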