From 70dc958fb813ed508777db3f3064aacead5be76a Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 5 Jul 2024 15:00:36 +0000 Subject: [PATCH] fix: refactor adapter weight loading and mapping --- .../text_generation_server/adapters/config.py | 13 +- .../text_generation_server/adapters/lora.py | 56 +++---- .../adapters/weights.py | 7 - .../text_generation_server/models/__init__.py | 122 ++++++++++++++- .../models/flash_causal_lm.py | 70 --------- .../models/flash_mistral.py | 85 ---------- server/text_generation_server/models/model.py | 148 +----------------- server/text_generation_server/server.py | 25 +-- .../text_generation_server/utils/adapter.py | 89 ++++++++--- 9 files changed, 220 insertions(+), 395 deletions(-) delete mode 100644 server/text_generation_server/models/flash_mistral.py diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py index 5261d4b5..2ee53b12 100644 --- a/server/text_generation_server/adapters/config.py +++ b/server/text_generation_server/adapters/config.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Set, Tuple import torch @@ -31,14 +31,3 @@ class AdapterConfig(ABC): weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: pass - - @abstractmethod - def load_batched_adapter_weights( - self, - model: "Model", - module_map: ModuleMap, - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 87543be2..d6f15465 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig): adapter_weight_names.add(lora_b_name) return module_map, adapter_weight_names - def load_batched_adapter_weights( - self, - model: "Model", - module_map: Dict[str, Dict], - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - return LoraWeights.load( - self, - model, - module_map, - layer_type, - unused_weight_names, - ) - @classmethod def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) @@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights): def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: return [BatchLoraWeights] + # prepare pre-loaded lora weights for use in the model. + # + # this method processes and organizes lora weights for a specific layer type across all layers: + # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor. + # - retrieves weights from `module_map` based on the `layer_type`. + # - processes `nlayers` number of layers. + # - converts weights to the specified `dtype`. + # - shards weights across `world_size` number of processes using the `process_group`. + # - maps weights to specific layers using `target_to_layer`. + # - tracks `unused_weight_names` to identify any unused weights. + # + # the method handles weight transposition, scaling, and padding to ensure compatibility + # with SGMV or BGMV operations. 
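+    #
+    # rough worked example (illustrative values, not taken from this repo): an
+    # adapter trained with lora_alpha=16 and r=8 is typically applied with
+    # scale = lora_alpha / r = 2.0 (or lora_alpha / sqrt(r) when rslora is
+    # enabled); that scale is folded into lora_b below, so no extra multiply
+    # is needed at inference time.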
@classmethod - def load( + def prepare_weights( cls, config: LoraConfig, - model: "Model", module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], + nlayers: int, + dtype: torch.dtype, + world_size: int, + process_group: ProcessGroup, + target_to_layer: Dict[str, Tuple[str, torch.Tensor]], ) -> Optional[AdapterWeights]: - nlayers = model.get_num_layers_for_type(layer_type) lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers for layer_id in range(nlayers): key = (layer_id, layer_type) - weight_name, layer = model.target_to_layer[key] + weight_name, layer = target_to_layer[key] base_weight = layer.base_layer.linear.weight base_device = base_weight.device @@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights): return None lora_a, lora_a_name = module_map[weight_name]["lora_A"] - lora_a = lora_a.to(base_device, model.dtype) + lora_a = lora_a.to(base_device, dtype) lora_b, lora_b_name = module_map[weight_name]["lora_B"] - lora_b = lora_b.to(base_device, model.dtype) + lora_b = lora_b.to(base_device, dtype) scale = get_scaling_factor( config.lora_alpha, @@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [ - pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list - ] - lora_b_list = [ - pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list - ] + lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] + lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] if lora_a_list: # update rank if it was padded @@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights): *shard_lora_weights( weights_a=lora_a_list, weights_b=lora_b_list, - split_dim=0 if model.is_row_parallel(layer_type) else 1, - process_group=model.process_group, + split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1, + process_group=process_group, ), config, ) diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py index 8f658756..2cc4ba6d 100644 --- a/server/text_generation_server/adapters/weights.py +++ b/server/text_generation_server/adapters/weights.py @@ -71,13 +71,6 @@ class LayerAdapterWeights: return del self.adapter_weights[adapter_idx] - @property - def max_speculative_tokens(self) -> int: - return max( - adapter_weights.speculative_tokens - for adapter_weights in self.adapter_weights.values() - ) - def is_empty(self) -> bool: return len(self.adapter_weights) == 0 diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a43cdfed..c0366623 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -33,6 +33,15 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( T5ForConditionalGeneration, ) + +from text_generation_server.utils.adapter import ( + AdapterParameters, + build_layer_weight_lookup, + load_and_merge_adapters, +) +from text_generation_server.adapters.lora import LoraWeights + + from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_master @@ -294,7 +303,7 @@ for data in ModelType: __GLOBALS[data.name] = data.value["type"] -def get_model( +def _get_model( model_id: str, lora_adapter_ids: Optional[List[str]], revision: Optional[str], @@ -1110,3 +1119,114 @@ def get_model( ) raise ValueError(f"Unsupported model type {model_type}") + 
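+# rough sketch of the adapter loading call flow after this refactor (every
+# piece named here is introduced or reworked in this patch):
+#
+#   serve() in server.py
+#     -> get_model(..., adapter_to_index)           # wrapper defined below
+#          -> _get_model(...)                       # builds the base model
+#          -> build_layer_weight_lookup(model.model)
+#          -> load_and_merge_adapters(...)          # resolves module_map + LoraConfig
+#          -> LoraWeights.prepare_weights(...)      # one call per adapter layer type
+#          -> model.layer_to_adapter_weights[layer_type].add_adapter(...)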
+ +# get_model wraps the internal _get_model function and adds support for loading adapters +# this provides a post model loading hook to load adapters into the model after the model has been loaded +def get_model( + model_id: str, + lora_adapter_ids: Optional[List[str]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[str], + trust_remote_code: bool, + max_input_tokens: int, + adapter_to_index: dict[str, int], +): + model = _get_model( + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + max_input_tokens, + ) + + if len(lora_adapter_ids) > 0: + target_to_layer = build_layer_weight_lookup(model.model) + + for index, adapter_id in enumerate(lora_adapter_ids): + # currenly we only load one adapter at a time but + # this can be extended to merge multiple adapters + adapter_parameters = AdapterParameters( + adapter_ids=[adapter_id], + weights=None, # will be set to 1 + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + + adapter_index = index + 1 + adapter_to_index[adapter_id] = adapter_index + + if adapter_index in model.loaded_adapters: + continue + + logger.info( + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + ) + weight_names = tuple([v[0] for v in target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + model.model_id, + adapter_parameters, + adapter_index, + weight_names, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + + adapter_layers = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + + for layer_name in adapter_layers: + nlayers = ( + 1 if layer_name == "lm_head" else len(model.model.model.layers) + ) + adapter_weights = LoraWeights.prepare_weights( + config=adapter_config, + module_map=module_map, + layer_type=layer_name, + unused_weight_names=unused_weight_names, + nlayers=nlayers, + dtype=model.dtype, + world_size=model.world_size, + process_group=model.process_group, + target_to_layer=target_to_layer, + ) + + if adapter_weights is None: + continue + + model.layer_to_adapter_weights[layer_name].add_adapter( + adapter_index, adapter_weights + ) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + model.loaded_adapters.add(adapter_index) + + return model diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index cfffafa1..7b31f507 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,7 +1,6 @@ import math import os import time -import itertools import torch import torch.distributed @@ -1695,72 +1694,3 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. 
- if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py deleted file mode 100644 index 2b2bd2e0..00000000 --- a/server/text_generation_server/models/flash_mistral.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM - - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashMistral(FlashCausalLM): - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. 
- if hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index e7748bb9..159139de 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -4,20 +4,12 @@ import torch from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict from collections import defaultdict -from transformers import PreTrainedTokenizerBase, PretrainedConfig +from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights -from text_generation_server.utils.adapter import ( - load_and_merge_adapters, - AdapterParameters, - AdapterSource, -) -from text_generation_server.utils.log import log_master -from loguru import logger - BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -61,7 +53,6 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -142,140 +133,3 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) - - @property - def supports_adapter_loading(self) -> bool: - return False - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - return {} - - @property - def adapter_layers(self) -> List[str]: - return [] - - @property - def default_traced_adapter_layers(self) -> List[str]: - return [] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 0 - - def is_row_parallel(self, layer_type: str) -> bool: - return False - - @property - def max_speculative_tokens(self) -> int: - return max( - [ - weights.max_speculative_tokens 
- for weights in self.layer_to_adapter_weights.values() - ], - default=0, - ) - - def load_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - api_token: str, - dynamic: bool = True, - ): - """Loads adapter weights from disk / host memory on the GPU. - - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded - into model. Otherwise, the adapter weights are applied during the forward - pass and stored separately from the base model parameters. - """ - if self.target_to_layer is None: - self.target_to_layer = self.adapter_target_to_layer() - if adapter_index in self.loaded_adapters: - # Adapter already loaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if dynamic and not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." - ) - - log_master( - logger.info, - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", - ) - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - ( - module_map, - adapter_config, - adapter_weight_names, - adapter_tokenizer, - ) = load_and_merge_adapters( - self.model_id, - adapter_parameters, - adapter_source, - adapter_index, - weight_names, - api_token, - False, - ) - - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - adapter_weights = adapter_config.load_batched_adapter_weights( - self, - module_map, - layer_name, - unused_weight_names, - dynamic, - ) - - if adapter_weights is None: - continue - - layer_weights = self.layer_to_adapter_weights[layer_name] - layer_weights.add_adapter(adapter_index, adapter_weights) - - if len(unused_weight_names) > 0: - log_master( - logger.warning, - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", - ) - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.loaded_adapters.add(adapter_index) - - def offload_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - ): - """Offloads the adapter weights from GPU to CPU or disk.""" - if adapter_index not in self.loaded_adapters: - # Adapter already offloaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." 
- ) - - for layer_name in self.adapter_layers: - if layer_name in self.layer_to_adapter_weights: - self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) - - self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index aee287c6..7455740a 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -30,9 +30,6 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.globals import set_model_id, set_adapter_to_index -from text_generation_server.utils.adapter import ( - AdapterParameters, -) class SignalHandler: @@ -238,29 +235,9 @@ def serve( dtype, trust_remote_code, max_input_tokens, + adapter_to_index, ) - if len(lora_adapter_ids) > 0: - for index, adapter_id in enumerate(lora_adapter_ids): - # TODO: improve non merged adapter loading and long term - # improve adapter loading as a whole - adapter_parameters = AdapterParameters( - adapter_ids=[adapter_id], - weights=None, # will be set to 1 - merge_strategy=0, - density=1.0, - majority_sign_method=0, - ) - adapter_index = index + 1 - adapter_to_index[adapter_id] = adapter_index - model.load_adapter( - adapter_parameters, - None, # adapter_source - adapter_index, - None, # api_token - False, # dynamic - ) - except Exception: logger.exception("Error when initializing model") raise diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 4e2492de..21f7bbbc 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Set, Tuple from safetensors.torch import load_file from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer -from text_generation_server.pb import generate_pb2 from text_generation_server.utils.merges.strategies import merge_adapters from text_generation_server.utils import hub @@ -43,34 +42,25 @@ class AdapterSource: def load_and_merge_adapters( model_id: str, adapter_parameters: AdapterParameters, - adapter_source: str, adapter_index: int, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: if len(adapter_parameters.adapter_ids) == 1: return load_module_map( model_id, adapter_parameters.adapter_ids[0], - adapter_source, weight_names, - api_token, trust_remote_code, ) - adapter_params = AdapterParametersContainer( - adapter_parameters, adapter_source, adapter_index - ) - return _load_and_merge( - model_id, adapter_params, weight_names, api_token, trust_remote_code - ) + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) + return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code) @dataclass class AdapterParametersContainer: adapter_parameters: AdapterParameters - adapter_source: str adapter_index: int def __hash__(self) -> int: @@ -82,7 +72,6 @@ def _load_and_merge( model_id: str, adapter_params: AdapterParametersContainer, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: params = adapter_params.adapter_parameters @@ -98,9 +87,7 @@ def _load_and_merge( load_module_map( model_id, adapter_id, - 
adapter_params.adapter_source, weight_names, - api_token, trust_remote_code, ) ) @@ -159,14 +146,12 @@ def check_architectures( def load_module_map( model_id: str, adapter_id: str, - adapter_source: str, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: revision = "main" - adapter_config = LoraConfig.load(adapter_id, api_token) + adapter_config = LoraConfig.load(adapter_id, None) if adapter_config.base_model_name_or_path != model_id: check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) @@ -177,7 +162,6 @@ def load_module_map( try: adapter_tokenizer = AutoTokenizer.from_pretrained( adapter_config.config_path, - token=api_token, trust_remote_code=trust_remote_code, ) except Exception: @@ -194,3 +178,70 @@ def load_module_map( adapter_weights, weight_names ) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer + + +def get_attn_weights(i, layer): + qkv = layer.self_attn.query_key_value + weights = {} + + for k in ["q", "k", "v"]: + key = (i, f"{k}_proj") + value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) + weights[key] = value + + weights[(i, "o_proj")] = ( + f"model.layers.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + return weights + + +def get_mlp_weights(i, layer): + weights = {} + + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): + gup = layer.mlp.gate_up_proj + + for k in ["gate", "up", "down"]: + key = (i, f"{k}_proj") + value = ( + f"model.layers.{i}.mlp.{k}_proj", + gup if k != "down" else layer.mlp.down_proj, + ) + weights[key] = value + + return weights + + +# build_layer_weight_lookup creates a mapping of model layers to their corresponding +# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples +# containing the weight tensor path and the actual layer object. This mapping is needed +# for the lora adapter to know which weights to update when applying the adapter. +def build_layer_weight_lookup(model): + if hasattr(model, "language_model"): + m = model.language_model.model + elif hasattr(model, "text_model"): + m = model.text_model.model + else: + m = model.model + + layer_weights = {} + + for i, layer in enumerate(m.layers): + attn_weights = get_attn_weights(i, layer) + mlp_weights = get_mlp_weights(i, layer) + + layer_weights.update(attn_weights) + layer_weights.update(mlp_weights) + + lm_head = None + if hasattr(m, "lm_head"): + lm_head = m.lm_head + elif hasattr(model, "lm_head"): + lm_head = model.lm_head + + if lm_head: + layer_weights[(0, "lm_head")] = ("lm_head", lm_head) + + return layer_weights
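
For orientation, a small self-contained sketch of the lookup table that build_layer_weight_lookup produces. The dummy objects and attribute names below stand in for a real Llama-style model (query_key_value, o_proj, gate_up_proj, down_proj, lm_head); only the shape of the returned dict is the point, and none of this code is part of the patch itself:

    # sketch: exercise build_layer_weight_lookup with dummy modules to show the
    # (layer_index, proj_name) -> (weight path, target module) mapping it builds
    from types import SimpleNamespace

    from text_generation_server.utils.adapter import build_layer_weight_lookup

    def dummy_layer():
        return SimpleNamespace(
            self_attn=SimpleNamespace(query_key_value="qkv", o_proj="o"),
            mlp=SimpleNamespace(gate_up_proj="gate_up", down_proj="down"),
        )

    dummy_model = SimpleNamespace(
        model=SimpleNamespace(layers=[dummy_layer(), dummy_layer()], lm_head="head")
    )

    lookup = build_layer_weight_lookup(dummy_model)
    print(lookup[(0, "q_proj")])     # ('model.layers.0.self_attn.q_proj', 'qkv')
    print(lookup[(1, "down_proj")])  # ('model.layers.1.mlp.down_proj', 'down')
    print(lookup[(0, "lm_head")])    # ('lm_head', 'head')

    # the weight-name half of each value is what get_model passes to
    # load_and_merge_adapters as `weight_names`
    weight_names = tuple(v[0] for v in lookup.values())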