From 4f1543d3c75ced99406b8e0fce71371fc39daa43 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 19 Jun 2024 16:13:42 +0000 Subject: [PATCH] fix: refactors and adjust flash llama lora logic --- .../adapters/__init__.py | 24 +----- .../text_generation_server/adapters/config.py | 13 ++- .../text_generation_server/adapters/lora.py | 18 ++-- .../adapters/weights.py | 4 + .../custom_modeling/flash_llama_modeling.py | 86 +++++++++---------- .../text_generation_server/utils/adapter.py | 36 ++++---- server/text_generation_server/utils/peft.py | 6 +- .../text_generation_server/utils/segments.py | 4 + server/text_generation_server/utils/sgmv.py | 4 + 9 files changed, 101 insertions(+), 94 deletions(-) diff --git a/server/text_generation_server/adapters/__init__.py b/server/text_generation_server/adapters/__init__.py index 0e6f6e45..8697cb9e 100644 --- a/server/text_generation_server/adapters/__init__.py +++ b/server/text_generation_server/adapters/__init__.py @@ -1,31 +1,13 @@ -import json -from pathlib import Path -from typing import Dict, Optional +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/__init__.py +# License: Apache License Version 2.0, January 2004 -from text_generation_server.adapters.config import AdapterConfig -from text_generation_server.adapters.lora import LoraConfig from text_generation_server.adapters.weights import ( AdapterBatchData, AdapterBatchMetadata, ) - -def load_adapter_config( - config_path: Optional[Path], - adapter_config_path: Optional[Path], - api_token: str, -) -> AdapterConfig: - if adapter_config_path is not None and adapter_config_path.exists(): - return LoraConfig.load(str(adapter_config_path.parent), api_token) - - raise ValueError( - f"No valid adapter config file found: " - f"tried {adapter_config_path} and {config_path}" - ) - - __all__ = [ "AdapterBatchData", "AdapterBatchMetadata", - "load_adapter_config", ] diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py index 653c7bc8..5261d4b5 100644 --- a/server/text_generation_server/adapters/config.py +++ b/server/text_generation_server/adapters/config.py @@ -1,3 +1,7 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/config.py +# License: Apache License Version 2.0, January 2004 + from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple @@ -10,7 +14,10 @@ if TYPE_CHECKING: from text_generation_server.models.model import Model -ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] +@dataclass +class ModuleMap: + module_name: str + module_weights: Dict[str, Tuple[torch.Tensor, str]] @dataclass @@ -20,7 +27,7 @@ class AdapterConfig(ABC): @abstractmethod def map_weights_for_model( self, - adapter_weights: Dict, + adapter_weights: Dict[int, AdapterWeights], weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: pass @@ -29,7 +36,7 @@ class AdapterConfig(ABC): def load_batched_adapter_weights( self, model: "Model", - module_map: Dict[str, Dict], + module_map: ModuleMap, layer_type: str, unused_weight_names: Set[str], dynamic: bool, diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index fb07ca28..d176d150 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -1,3 +1,7 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/lora.py +# License: 
Apache License Version 2.0, January 2004 + from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union @@ -25,8 +29,6 @@ from text_generation_server.utils.sgmv import ( if TYPE_CHECKING: from text_generation_server.models.model import Model -EMPTY_TENSOR = torch.tensor([]) - @dataclass class LoraConfig(AdapterConfig): @@ -38,7 +40,7 @@ class LoraConfig(AdapterConfig): def map_weights_for_model( self, - adapter_weights: Dict, + adapter_weights: Dict[int, AdapterWeights], weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: adapter_weight_names = set() @@ -262,7 +264,7 @@ class BatchLoraWeights(BatchAdapterWeights): if not adapter_weights: return None - first_weights = list(adapter_weights.values())[0] + first_weights = next(iter(adapter_weights.values())) device = first_weights.weights_a.device segment_indices = meta.segment_indices @@ -293,7 +295,7 @@ class BatchLoraWeights(BatchAdapterWeights): ( adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights - else EMPTY_TENSOR.data_ptr() + else 0 ) for idx in segment_indices ], @@ -305,7 +307,7 @@ class BatchLoraWeights(BatchAdapterWeights): ( adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights - else EMPTY_TENSOR.data_ptr() + else 0 ) for idx in segment_indices ], @@ -319,7 +321,7 @@ class BatchLoraWeights(BatchAdapterWeights): ( adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights - else EMPTY_TENSOR.data_ptr() + else 0 ) for idx in segment_indices ], @@ -331,7 +333,7 @@ class BatchLoraWeights(BatchAdapterWeights): ( adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights - else EMPTY_TENSOR.data_ptr() + else 0 ) for idx in segment_indices ], diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py index 50c072ca..8f658756 100644 --- a/server/text_generation_server/adapters/weights.py +++ b/server/text_generation_server/adapters/weights.py @@ -1,3 +1,7 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/weights.py +# License: Apache License Version 2.0, January 2004 + from abc import ABC, abstractclassmethod from collections import defaultdict from dataclasses import dataclass diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 45027b1b..06558379 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -59,7 +59,7 @@ def load_attention(config, prefix, weights, layer_id): # if specific model type, load the correct attention if config.model_type == "phi3": - base_layer = TensorParallelColumnLinear.load_qkv( + return TensorParallelColumnLinear.load_qkv( config, prefix=f"{prefix}.qkv_proj", weights=weights, @@ -68,7 +68,7 @@ def load_attention(config, prefix, weights, layer_id): num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": - base_layer = TensorParallelColumnLinear.load_qkv( + return TensorParallelColumnLinear.load_qkv( config, prefix=f"{prefix}.W_pack", weights=weights, @@ -76,28 +76,28 @@ def load_attention(config, prefix, weights, layer_id): num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + else: + # otherwise, load the default attention based on the number of heads + 
base_layer = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=bias, + ) - # otherwise, load the default attention based on the number of heads - base_layer = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=bias, - ) - - head_size = config.hidden_size // config.num_attention_heads - return TensorParallelMultiAdapterLinear.load( - base_layer, - layer_id, - ["q_proj", "k_proj", "v_proj"], - sizes=[ - head_size * config.num_attention_heads, - head_size * config.num_key_value_heads, - head_size * config.num_key_value_heads, - ], - process_group=weights.process_group, - ) + head_size = config.hidden_size // config.num_attention_heads + return TensorParallelMultiAdapterLinear.load( + base_layer, + layer_id, + ["q_proj", "k_proj", "v_proj"], + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) class FlashLlamaAttention(torch.nn.Module): @@ -240,7 +240,7 @@ class LlamaMLP(nn.Module): # Fuse gate and up proj bias = getattr(config, "mlp_bias", False) if config.model_type == "phi3": - gate_up_proj = TensorParallelColumnLinear.load_gate_up( + self.gate_up_proj = TensorParallelColumnLinear.load_gate_up( config, prefix=f"{prefix}.gate_up_proj", weights=weights, @@ -255,16 +255,16 @@ class LlamaMLP(nn.Module): bias=bias, ) - self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, - index, - ["gate_proj", "up_proj"], - sizes=[ - config.intermediate_size, - config.intermediate_size, - ], - process_group=weights.process_group, - ) + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + ["gate_proj", "up_proj"], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) down_proj = TensorParallelRowLinear.load( config, @@ -273,12 +273,15 @@ class LlamaMLP(nn.Module): bias=bias, ) - self.down_proj = TensorParallelAdapterRowLinear.load( - down_proj, - index, - "down_proj", - process_group=weights.process_group, - ) + if config.model_type == "phi3": + self.down_proj = down_proj + else: + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() @@ -471,9 +474,6 @@ class FlashLlamaForCausalLM(torch.nn.Module): weights=weights, ) - def get_lora_index(self, adapter_id): - return self.model.layers[0].self_attn.key_to_index[adapter_id] - def forward( self, input_ids: torch.Tensor, diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 1d3a6442..4e2492de 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -1,3 +1,7 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/adapter.py +# License: Apache License Version 2.0, January 2004 + import warnings from dataclasses import dataclass from functools import lru_cache @@ -20,22 +24,20 @@ if TYPE_CHECKING: BASE_MODEL_ADAPTER_ID = "__base_model__" +@dataclass class AdapterParameters: - def __init__( - self, adapter_ids, weights, merge_strategy, density, majority_sign_method - ): - self.adapter_ids = adapter_ids - 
self.weights = weights - self.merge_strategy = merge_strategy - self.density = density - self.majority_sign_method = majority_sign_method + adapter_ids: Tuple[str] + weights: Tuple[float] + merge_strategy: NotImplemented + density: float + majority_sign_method: NotImplemented +@dataclass class AdapterSource: - def __init__(self, adapter_id: str, model_id: str, revision: str): - self.adapter_id = adapter_id - self.model_id = model_id - self.revision = revision + adapter_id: str + model_id: str + revision: str def load_and_merge_adapters( @@ -65,11 +67,11 @@ def load_and_merge_adapters( ) +@dataclass class AdapterParametersContainer: - def __init__(self, adapter_parameters, adapter_source, adapter_index): - self.adapter_parameters = adapter_parameters - self.adapter_source = adapter_source - self.adapter_index = adapter_index + adapter_parameters: AdapterParameters + adapter_source: str + adapter_index: int def __hash__(self) -> int: return self.adapter_index @@ -123,7 +125,7 @@ def check_architectures( ): try: if not adapter_config.base_model_name_or_path: - # Avoid execuation latency caused by the network connection retrying for AutoConfig.from_pretrained(None) + # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None) return expected_config = AutoConfig.from_pretrained( diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index 5aaeb5ac..0ea89267 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -1,5 +1,5 @@ import os -import json +from typing import Union from loguru import logger import torch @@ -45,7 +45,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): tokenizer.save_pretrained(cache_dir) -def download_peft(model_id, revision, trust_remote_code): +def download_peft( + model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool +): torch_dtype = torch.float16 try: _model = AutoPeftModelForCausalLM.from_pretrained( diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py index 0a50f20f..f5961102 100644 --- a/server/text_generation_server/utils/segments.py +++ b/server/text_generation_server/utils/segments.py @@ -1,3 +1,7 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/segments.py +# License: Apache License Version 2.0, January 2004 + from typing import List, Tuple, Union import torch diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py index 7ad6288d..e0aec25f 100644 --- a/server/text_generation_server/utils/sgmv.py +++ b/server/text_generation_server/utils/sgmv.py @@ -1,3 +1,7 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/sgmv.py +# License: Apache License Version 2.0, January 2004 + import os import warnings from functools import lru_cache
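
Reviewer note (not part of the patch): as I read the diff, the BatchLoraWeights change drops the module-level EMPTY_TENSOR and instead records a literal 0 for any segment whose adapter index has no loaded weights, so the pointer arrays handed to the SGMV path carry a null pointer rather than the address of a throwaway empty tensor; the phi3/baichuan branches likewise appear to return the plain fused layers and leave only the default split q/k/v and gate/up projections wrapped by the multi-adapter classes. Below is a minimal, standalone sketch of the null-pointer convention using made-up segment data; the names mirror the patch but this is not the actual BatchLoraWeights.load implementation.

    import torch

    # Three request segments; only adapters 0 and 2 have LoRA weights loaded.
    adapter_weights = {
        0: torch.randn(16, 8),  # stand-in for weights_a of adapter 0
        2: torch.randn(16, 8),  # stand-in for weights_a of adapter 2
    }
    segment_indices = [0, 1, 2]

    # As in the patch: a segment with no adapter contributes a literal 0
    # (null pointer) instead of EMPTY_TENSOR.data_ptr().
    lora_a_ptr = torch.tensor(
        [
            adapter_weights[idx].data_ptr() if idx in adapter_weights else 0
            for idx in segment_indices
        ],
        dtype=torch.int64,
    )

    print(lora_a_ptr)  # the middle entry is 0, marking "no adapter here"

The upside of this convention, presumably, is that a zero entry is unambiguous to downstream kernels and does not depend on keeping a dummy tensor alive for the lifetime of the batch.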