fix: refactor adapter weight loading and mapping

commit 70dc958fb8
parent 6aebf44f47
Author: drbh
Date:   2024-07-05 15:00:36 +00:00

9 changed files with 220 additions and 395 deletions

View File

@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Set, Tuple

 import torch
@@ -31,14 +31,3 @@ class AdapterConfig(ABC):
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
-
-    @abstractmethod
-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: ModuleMap,
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        pass

View File

@@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig):
             adapter_weight_names.add(lora_b_name)
         return module_map, adapter_weight_names

-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: Dict[str, Dict],
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        return LoraWeights.load(
-            self,
-            model,
-            module_map,
-            layer_type,
-            unused_weight_names,
-        )
-
     @classmethod
     def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
         hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
@@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights):
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
         return [BatchLoraWeights]

+    # prepare pre-loaded lora weights for use in the model.
+    #
+    # this method processes and organizes lora weights for a specific layer type across all layers:
+    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
+    # - retrieves weights from `module_map` based on the `layer_type`.
+    # - processes `nlayers` number of layers.
+    # - converts weights to the specified `dtype`.
+    # - shards weights across `world_size` number of processes using the `process_group`.
+    # - maps weights to specific layers using `target_to_layer`.
+    # - tracks `unused_weight_names` to identify any unused weights.
+    #
+    # the method handles weight transposition, scaling, and padding to ensure compatibility
+    # with SGMV or BGMV operations.
     @classmethod
-    def load(
+    def prepare_weights(
         cls,
         config: LoraConfig,
-        model: "Model",
         module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
+        nlayers: int,
+        dtype: torch.dtype,
+        world_size: int,
+        process_group: ProcessGroup,
+        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
     ) -> Optional[AdapterWeights]:
-        nlayers = model.get_num_layers_for_type(layer_type)
         lora_a_list = [None] * nlayers
         lora_b_list = [None] * nlayers

         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
-            weight_name, layer = model.target_to_layer[key]
+            weight_name, layer = target_to_layer[key]

             base_weight = layer.base_layer.linear.weight
             base_device = base_weight.device
@@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights):
                 return None

             lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-            lora_a = lora_a.to(base_device, model.dtype)
+            lora_a = lora_a.to(base_device, dtype)

             lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-            lora_b = lora_b.to(base_device, model.dtype)
+            lora_b = lora_b.to(base_device, dtype)

             scale = get_scaling_factor(
                 config.lora_alpha,
@@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights):
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
-        ]
+        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
+        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

         if lora_a_list:
             # update rank if it was padded
@@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights):
             *shard_lora_weights(
                 weights_a=lora_a_list,
                 weights_b=lora_b_list,
-                split_dim=0 if model.is_row_parallel(layer_type) else 1,
-                process_group=model.process_group,
+                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
+                process_group=process_group,
             ),
             config,
         )
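To make the transposition-and-scaling step described in the comment above concrete, here is a small illustrative sketch (not part of the commit; the shapes and the scale value are made up) showing that folding the scaling factor into the transposed lora_b gives the same delta as scaling afterwards:

import torch

hidden, rank = 64, 8                      # hypothetical hidden size and LoRA rank
lora_a = torch.randn(rank, hidden)        # checkpoint layout: (rank, hidden)
lora_b = torch.randn(hidden, rank)        # checkpoint layout: (hidden, rank)
scale = 0.5                               # stands in for get_scaling_factor(lora_alpha, r)
x = torch.randn(4, hidden)                # a batch of hidden states

# prepare_weights stores A^T and (B^T * scale); by associativity of matrix
# multiplication this equals computing the delta first and scaling it at the end.
a_t = lora_a.transpose(0, 1)                   # (hidden, rank)
b_t_scaled = lora_b.transpose(0, 1) * scale    # (rank, hidden)

delta_folded = x @ a_t @ b_t_scaled
delta_explicit = (x @ a_t @ lora_b.transpose(0, 1)) * scale
assert torch.allclose(delta_folded, delta_explicit, atol=1e-5)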

View File

@@ -71,13 +71,6 @@ class LayerAdapterWeights:
             return
         del self.adapter_weights[adapter_idx]

-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            adapter_weights.speculative_tokens
-            for adapter_weights in self.adapter_weights.values()
-        )
-
     def is_empty(self) -> bool:
         return len(self.adapter_weights) == 0

View File

@@ -33,6 +33,15 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
     T5ForConditionalGeneration,
 )

+from text_generation_server.utils.adapter import (
+    AdapterParameters,
+    build_layer_weight_lookup,
+    load_and_merge_adapters,
+)
+from text_generation_server.adapters.lora import LoraWeights
+
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_master
@@ -294,7 +303,7 @@ for data in ModelType:
     __GLOBALS[data.name] = data.value["type"]


-def get_model(
+def _get_model(
     model_id: str,
     lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
@@ -1110,3 +1119,114 @@ def get_model(
         )

     raise ValueError(f"Unsupported model type {model_type}")
+
+
+# get_model wraps the internal _get_model function and adds support for loading adapters
+# this provides a post model loading hook to load adapters into the model after the model has been loaded
+def get_model(
+    model_id: str,
+    lora_adapter_ids: Optional[List[str]],
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    speculate: Optional[int],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    max_input_tokens: int,
+    adapter_to_index: dict[str, int],
+):
+    model = _get_model(
+        model_id,
+        lora_adapter_ids,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        max_input_tokens,
+    )
+
+    if len(lora_adapter_ids) > 0:
+        target_to_layer = build_layer_weight_lookup(model.model)
+
+        for index, adapter_id in enumerate(lora_adapter_ids):
+            # currenly we only load one adapter at a time but
+            # this can be extended to merge multiple adapters
+            adapter_parameters = AdapterParameters(
+                adapter_ids=[adapter_id],
+                weights=None,  # will be set to 1
+                merge_strategy=0,
+                density=1.0,
+                majority_sign_method=0,
+            )
+
+            adapter_index = index + 1
+            adapter_to_index[adapter_id] = adapter_index
+
+            if adapter_index in model.loaded_adapters:
+                continue
+
+            logger.info(
+                f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+            )
+
+            weight_names = tuple([v[0] for v in target_to_layer.values()])
+
+            (
+                module_map,
+                adapter_config,
+                adapter_weight_names,
+                adapter_tokenizer,
+            ) = load_and_merge_adapters(
+                model.model_id,
+                adapter_parameters,
+                adapter_index,
+                weight_names,
+                False,
+            )
+
+            unused_weight_names = adapter_weight_names.copy()
+
+            adapter_layers = [
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ]
+
+            for layer_name in adapter_layers:
+                nlayers = (
+                    1 if layer_name == "lm_head" else len(model.model.model.layers)
+                )
+                adapter_weights = LoraWeights.prepare_weights(
+                    config=adapter_config,
+                    module_map=module_map,
+                    layer_type=layer_name,
+                    unused_weight_names=unused_weight_names,
+                    nlayers=nlayers,
+                    dtype=model.dtype,
+                    world_size=model.world_size,
+                    process_group=model.process_group,
+                    target_to_layer=target_to_layer,
+                )
+
+                if adapter_weights is None:
+                    continue
+
+                model.layer_to_adapter_weights[layer_name].add_adapter(
+                    adapter_index, adapter_weights
+                )
+
+            if len(unused_weight_names) > 0:
+                logger.warning(
+                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                )
+
+            if adapter_tokenizer is not None:
+                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+            model.loaded_adapters.add(adapter_index)
+
+    return model
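A toy illustration (not from the commit) of the indexing convention used above: adapter index 0 is reserved for the base model, so user-supplied adapters are registered starting at index 1. The adapter ids below are hypothetical.

lora_adapter_ids = ["org/adapter-a", "org/adapter-b"]   # hypothetical adapter ids
adapter_to_index = {}
for index, adapter_id in enumerate(lora_adapter_ids):
    adapter_to_index[adapter_id] = index + 1            # index 0 is implicitly the base model
print(adapter_to_index)  # {'org/adapter-a': 1, 'org/adapter-b': 2}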

View File

@@ -1,7 +1,6 @@
 import math
 import os
 import time
-import itertools

 import torch
 import torch.distributed
@@ -1695,72 +1694,3 @@ class FlashCausalLM(Model):
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL

View File

@@ -1,85 +0,0 @@
-import torch
-
-from typing import Optional, Tuple, Dict, List
-
-from text_generation_server.models import FlashCausalLM
-
-
-ADAPTER_LAYERS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
-ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
-
-
-class FlashMistral(FlashCausalLM):
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL

View File

@@ -4,20 +4,12 @@ import torch
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
 from collections import defaultdict
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase

 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
-from text_generation_server.utils.adapter import (
-    load_and_merge_adapters,
-    AdapterParameters,
-    AdapterSource,
-)
-from text_generation_server.utils.log import log_master
-from loguru import logger

 BASE_MODEL_ADAPTER_ID = "__base_model__"
@@ -61,7 +53,6 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -142,140 +133,3 @@ class Model(ABC):
             raise RuntimeError(
                 f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
             )
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return False
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        return {}
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return []
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return []
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 0
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return False
-
-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            [
-                weights.max_speculative_tokens
-                for weights in self.layer_to_adapter_weights.values()
-            ],
-            default=0,
-        )
-
-    def load_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-        api_token: str,
-        dynamic: bool = True,
-    ):
-        """Loads adapter weights from disk / host memory on the GPU.
-
-        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
-        into model. Otherwise, the adapter weights are applied during the forward
-        pass and stored separately from the base model parameters.
-        """
-        if self.target_to_layer is None:
-            self.target_to_layer = self.adapter_target_to_layer()
-        if adapter_index in self.loaded_adapters:
-            # Adapter already loaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if dynamic and not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        log_master(
-            logger.info,
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
-        )
-        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
-        (
-            module_map,
-            adapter_config,
-            adapter_weight_names,
-            adapter_tokenizer,
-        ) = load_and_merge_adapters(
-            self.model_id,
-            adapter_parameters,
-            adapter_source,
-            adapter_index,
-            weight_names,
-            api_token,
-            False,
-        )
-
-        unused_weight_names = adapter_weight_names.copy()
-        for layer_name in self.adapter_layers:
-            adapter_weights = adapter_config.load_batched_adapter_weights(
-                self,
-                module_map,
-                layer_name,
-                unused_weight_names,
-                dynamic,
-            )
-
-            if adapter_weights is None:
-                continue
-
-            layer_weights = self.layer_to_adapter_weights[layer_name]
-            layer_weights.add_adapter(adapter_index, adapter_weights)
-
-        if len(unused_weight_names) > 0:
-            log_master(
-                logger.warning,
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
-            )
-
-        if adapter_tokenizer is not None:
-            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
-
-        self.loaded_adapters.add(adapter_index)
-
-    def offload_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-    ):
-        """Offloads the adapter weights from GPU to CPU or disk."""
-        if adapter_index not in self.loaded_adapters:
-            # Adapter already offloaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        for layer_name in self.adapter_layers:
-            if layer_name in self.layer_to_adapter_weights:
-                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
-
-        self.loaded_adapters.remove(adapter_index)

View File

@@ -30,9 +30,6 @@ except (ImportError, NotImplementedError):
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id, set_adapter_to_index
-from text_generation_server.utils.adapter import (
-    AdapterParameters,
-)


 class SignalHandler:
@@ -238,27 +235,7 @@ def serve(
                 dtype,
                 trust_remote_code,
                 max_input_tokens,
+                adapter_to_index,
             )
-
-            if len(lora_adapter_ids) > 0:
-                for index, adapter_id in enumerate(lora_adapter_ids):
-                    # TODO: improve non merged adapter loading and long term
-                    # improve adapter loading as a whole
-                    adapter_parameters = AdapterParameters(
-                        adapter_ids=[adapter_id],
-                        weights=None,  # will be set to 1
-                        merge_strategy=0,
-                        density=1.0,
-                        majority_sign_method=0,
-                    )
-                    adapter_index = index + 1
-                    adapter_to_index[adapter_id] = adapter_index
-                    model.load_adapter(
-                        adapter_parameters,
-                        None,  # adapter_source
-                        adapter_index,
-                        None,  # api_token
-                        False,  # dynamic
-                    )
         except Exception:

View File

@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Set, Tuple
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

-from text_generation_server.pb import generate_pb2
 from text_generation_server.utils.merges.strategies import merge_adapters
 from text_generation_server.utils import hub
@@ -43,34 +42,25 @@ class AdapterSource:
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
-    adapter_source: str,
     adapter_index: int,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     if len(adapter_parameters.adapter_ids) == 1:
         return load_module_map(
             model_id,
             adapter_parameters.adapter_ids[0],
-            adapter_source,
             weight_names,
-            api_token,
             trust_remote_code,
         )

-    adapter_params = AdapterParametersContainer(
-        adapter_parameters, adapter_source, adapter_index
-    )
-    return _load_and_merge(
-        model_id, adapter_params, weight_names, api_token, trust_remote_code
-    )
+    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
+    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)


 @dataclass
 class AdapterParametersContainer:
     adapter_parameters: AdapterParameters
-    adapter_source: str
     adapter_index: int

     def __hash__(self) -> int:
@@ -82,7 +72,6 @@ def _load_and_merge(
     model_id: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     params = adapter_params.adapter_parameters
@@ -98,9 +87,7 @@ def _load_and_merge(
             load_module_map(
                 model_id,
                 adapter_id,
-                adapter_params.adapter_source,
                 weight_names,
-                api_token,
                 trust_remote_code,
             )
         )
@@ -159,14 +146,12 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
-    adapter_source: str,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"
-    adapter_config = LoraConfig.load(adapter_id, api_token)
+    adapter_config = LoraConfig.load(adapter_id, None)

     if adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
@@ -177,7 +162,6 @@ def load_module_map(
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
             adapter_config.config_path,
-            token=api_token,
             trust_remote_code=trust_remote_code,
         )
     except Exception:
@@ -194,3 +178,70 @@ def load_module_map(
         adapter_weights, weight_names
     )
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def get_attn_weights(i, layer):
+    qkv = layer.self_attn.query_key_value
+    weights = {}
+
+    for k in ["q", "k", "v"]:
+        key = (i, f"{k}_proj")
+        value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
+        weights[key] = value
+
+    weights[(i, "o_proj")] = (
+        f"model.layers.{i}.self_attn.o_proj",
+        layer.self_attn.o_proj,
+    )
+
+    return weights
+
+
+def get_mlp_weights(i, layer):
+    weights = {}
+    if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
+        gup = layer.mlp.gate_up_proj
+
+        for k in ["gate", "up", "down"]:
+            key = (i, f"{k}_proj")
+            value = (
+                f"model.layers.{i}.mlp.{k}_proj",
+                gup if k != "down" else layer.mlp.down_proj,
+            )
+            weights[key] = value
+
+    return weights
+
+
+# build_layer_weight_lookup creates a mapping of model layers to their corresponding
+# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
+# containing the weight tensor path and the actual layer object. This mapping is needed
+# for the lora adapter to know which weights to update when applying the adapter.
+def build_layer_weight_lookup(model):
+    if hasattr(model, "language_model"):
+        m = model.language_model.model
+    elif hasattr(model, "text_model"):
+        m = model.text_model.model
+    else:
+        m = model.model
+
+    layer_weights = {}
+
+    for i, layer in enumerate(m.layers):
+        attn_weights = get_attn_weights(i, layer)
+        mlp_weights = get_mlp_weights(i, layer)
+
+        layer_weights.update(attn_weights)
+        layer_weights.update(mlp_weights)
+
+    lm_head = None
+    if hasattr(m, "lm_head"):
+        lm_head = m.lm_head
+    elif hasattr(model, "lm_head"):
+        lm_head = model.lm_head
+
+    if lm_head:
+        layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
+
+    return layer_weights
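As a rough sketch of what build_layer_weight_lookup returns (assuming the text_generation_server package is importable; the model below is a stand-in built from SimpleNamespace and nn.Linear, not a real architecture): keys are (layer_id, layer_type) pairs and values are (weight path, module) tuples.

from types import SimpleNamespace

import torch.nn as nn

from text_generation_server.utils.adapter import build_layer_weight_lookup


def fake_layer():
    # mimics the attributes the lookup walks: query_key_value, o_proj, gate_up_proj, down_proj
    return SimpleNamespace(
        self_attn=SimpleNamespace(
            query_key_value=nn.Linear(8, 24), o_proj=nn.Linear(8, 8)
        ),
        mlp=SimpleNamespace(gate_up_proj=nn.Linear(8, 16), down_proj=nn.Linear(8, 8)),
    )


# two decoder layers plus an lm_head, mirroring the `model.model` structure the function walks
fake_model = SimpleNamespace(
    model=SimpleNamespace(layers=[fake_layer(), fake_layer()], lm_head=nn.Linear(8, 32))
)

lookup = build_layer_weight_lookup(fake_model)
print(lookup[(0, "q_proj")][0])     # model.layers.0.self_attn.q_proj
print(lookup[(1, "down_proj")][0])  # model.layers.1.mlp.down_proj
print((0, "lm_head") in lookup)     # True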