From 611225f01701f5ad3dd468d3ae94ee6afab636c8 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 7 Jun 2024 01:20:41 +0000 Subject: [PATCH] feat: support base model generation and refactors --- .../text_generation_server/adapters/lora.py | 12 +-- .../adapters/weights.py | 9 +-- server/text_generation_server/cli.py | 1 - .../custom_modeling/flash_llama_modeling.py | 18 +---- .../models/flash_llama.py | 48 +++++------- .../models/flash_mistral.py | 51 +++++-------- server/text_generation_server/server.py | 2 +- server/text_generation_server/utils/lora.py | 74 ------------------- 8 files changed, 50 insertions(+), 165 deletions(-) delete mode 100644 server/text_generation_server/utils/lora.py diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 458a22e1..fb07ca28 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -8,7 +8,6 @@ from torch.distributed import ProcessGroup from text_generation_server.adapters.config import AdapterConfig, ModuleMap -LORA = "lora" from text_generation_server.adapters.weights import ( AdapterBatchMetadata, AdapterWeights, @@ -246,7 +245,7 @@ class BatchLoraWeights(BatchAdapterWeights): @classmethod def key(cls) -> str: - return LORA + return "lora" @classmethod def load( @@ -279,9 +278,12 @@ class BatchLoraWeights(BatchAdapterWeights): } max_rank = max( - adapter_weights[idx].lora_a_r - for idx in segment_indices - if idx in adapter_weights + ( + adapter_weights[idx].lora_a_r + for idx in segment_indices + if idx in adapter_weights + ), + default=0, ) if prefill or max_rank > BGMV_MAX_RANK: diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py index 2ed08df5..50c072ca 100644 --- a/server/text_generation_server/adapters/weights.py +++ b/server/text_generation_server/adapters/weights.py @@ -1,4 +1,3 @@ -############# from abc import ABC, abstractclassmethod from collections import defaultdict from dataclasses import dataclass @@ -7,10 +6,6 @@ from typing import Dict, List, Optional, Set, Type import torch -LORA = "lora" -LM_HEAD = "lm_head" - - @dataclass class AdapterBatchMetadata: # [batch_size] @@ -127,7 +122,7 @@ class AdapterBatchData: if v.is_empty(): continue data[k] = v.get_data( - meta, prefill, prefill_head_indices if k == LM_HEAD else None + meta, prefill, prefill_head_indices if k == "lm_head" else None ) return AdapterBatchData(meta=meta, data=data, prefill=prefill) @@ -135,7 +130,7 @@ class AdapterBatchData: # TODO(travis): refactor to be less coupled to lora implementation ranks = set() for layer_data in self.data.values(): - lora_data = layer_data.get(LORA) + lora_data = layer_data.get("lora") if lora_data is None: continue diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 45c2fab9..2c066c5c 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -78,7 +78,6 @@ def serve( if otlp_endpoint is not None: setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) - # TODO: determine if this api makes sense lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) # split on comma and strip whitespace diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 436c2f53..45027b1b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -52,16 +52,6 @@ if SYSTEM == "rocm": except Exception as e: raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -# Constants -Q_PROJ = "q_proj" -K_PROJ = "k_proj" -V_PROJ = "v_proj" -O_PROJ = "o_proj" - -GATE_PROJ = "gate_proj" -UP_PROJ = "up_proj" -DOWN_PROJ = "down_proj" - def load_attention(config, prefix, weights, layer_id): # Only defined in granite. @@ -100,7 +90,7 @@ def load_attention(config, prefix, weights, layer_id): return TensorParallelMultiAdapterLinear.load( base_layer, layer_id, - [Q_PROJ, K_PROJ, V_PROJ], + ["q_proj", "k_proj", "v_proj"], sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, @@ -160,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module): self.o_proj = TensorParallelAdapterRowLinear.load( o_proj, index, - O_PROJ, + "o_proj", process_group=weights.process_group, ) @@ -268,7 +258,7 @@ class LlamaMLP(nn.Module): self.gate_up_proj = TensorParallelMultiAdapterLinear.load( gate_up_proj, index, - [GATE_PROJ, UP_PROJ], + ["gate_proj", "up_proj"], sizes=[ config.intermediate_size, config.intermediate_size, @@ -286,7 +276,7 @@ class LlamaMLP(nn.Module): self.down_proj = TensorParallelAdapterRowLinear.load( down_proj, index, - DOWN_PROJ, + "down_proj", process_group=weights.process_group, ) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 1266f6de..327e4a6f 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -20,31 +20,17 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.lora import LoraConfig -Q_PROJ = "q_proj" -K_PROJ = "k_proj" -V_PROJ = "v_proj" -O_PROJ = "o_proj" - -GATE_PROJ = "gate_proj" -UP_PROJ = "up_proj" -DOWN_PROJ = "down_proj" - -LM_HEAD = "lm_head" - - -# TODO(travis): re-enable LM_HEAD after resolving issues with outputs ADAPTER_LAYERS = [ - Q_PROJ, - K_PROJ, - V_PROJ, - O_PROJ, - GATE_PROJ, - UP_PROJ, - DOWN_PROJ, -] # LM_HEAD -ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} class FlashLlama(FlashCausalLM): @@ -123,32 +109,32 @@ class FlashLlama(FlashCausalLM): prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = ( + layer_weights[(i, "q_proj")] = ( f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value, ) - layer_weights[(i, K_PROJ)] = ( + layer_weights[(i, "k_proj")] = ( f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value, ) - layer_weights[(i, V_PROJ)] = ( + layer_weights[(i, "v_proj")] = ( f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value, ) - layer_weights[(i, O_PROJ)] = ( + layer_weights[(i, "o_proj")] = ( f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj, ) - layer_weights[(i, GATE_PROJ)] = ( + layer_weights[(i, "gate_proj")] = ( f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj, ) - layer_weights[(i, UP_PROJ)] = ( + layer_weights[(i, "up_proj")] = ( f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj, ) - layer_weights[(i, DOWN_PROJ)] = ( + layer_weights[(i, "down_proj")] = ( f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj, ) @@ -162,7 +148,7 @@ class FlashLlama(FlashCausalLM): @property def 
default_traced_adapter_layers(self) -> List[str]: - return [Q_PROJ, V_PROJ] + return ["q_proj", "v_proj"] def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 37cc0235..90a95c41 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -21,29 +21,16 @@ from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) -Q_PROJ = "q_proj" -K_PROJ = "k_proj" -V_PROJ = "v_proj" -O_PROJ = "o_proj" - -GATE_PROJ = "gate_proj" -UP_PROJ = "up_proj" -DOWN_PROJ = "down_proj" - -LM_HEAD = "lm_head" - - -# TODO(travis): re-enable LM_HEAD after resolving issues with outputs ADAPTER_LAYERS = [ - Q_PROJ, - K_PROJ, - V_PROJ, - O_PROJ, - GATE_PROJ, - UP_PROJ, - DOWN_PROJ, -] # LM_HEAD -ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} class BaseFlashMistral(FlashCausalLM): @@ -133,37 +120,37 @@ class BaseFlashMistral(FlashCausalLM): prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = ( + layer_weights[(i, "q_proj")] = ( f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value, ) - layer_weights[(i, K_PROJ)] = ( + layer_weights[(i, "k_proj")] = ( f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value, ) - layer_weights[(i, V_PROJ)] = ( + layer_weights[(i, "v_proj")] = ( f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value, ) - layer_weights[(i, O_PROJ)] = ( + layer_weights[(i, "o_proj")] = ( f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj, ) - layer_weights[(i, GATE_PROJ)] = ( + layer_weights[(i, "gate_proj")] = ( f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj, ) - layer_weights[(i, UP_PROJ)] = ( + layer_weights[(i, "up_proj")] = ( f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj, ) - layer_weights[(i, DOWN_PROJ)] = ( + layer_weights[(i, "down_proj")] = ( f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj, ) - layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) + layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head) return layer_weights @property @@ -172,10 +159,10 @@ class BaseFlashMistral(FlashCausalLM): @property def default_traced_adapter_layers(self) -> List[str]: - return [Q_PROJ, V_PROJ] + return ["q_proj", "v_proj"] def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == LM_HEAD else len(self.model.model.layers) + return 1 if layer_type == "lm_head" else len(self.model.model.layers) def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 36524c82..3bb0dae9 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -251,7 +251,7 @@ def serve( density=1.0, majority_sign_method=0, ) - adapter_index = index + adapter_index = index + 1 adapter_to_index[adapter_id] = adapter_index model.load_adapter( adapter_parameters, diff --git a/server/text_generation_server/utils/lora.py b/server/text_generation_server/utils/lora.py deleted file mode 100644 index 8eed3a97..00000000 --- a/server/text_generation_server/utils/lora.py +++ /dev/null @@ -1,74 
+0,0 @@ -import json -from text_generation_server.utils import ( - hub, -) -import os - - -class LoraConfig: - def __init__( - self, - alpha_pattern=None, - auto_mapping=None, - base_model_name_or_path="", - bias="none", - fan_in_fan_out=False, - inference_mode=True, - init_lora_weights=True, - layer_replication=None, - layers_pattern=None, - layers_to_transform=None, - loftq_config=None, - lora_alpha=16, - lora_dropout=0.1, - megatron_config=None, - megatron_core="megatron.core", - modules_to_save=None, - peft_type="LORA", - r=8, - rank_pattern=None, - revision=None, - target_modules=None, - task_type="CAUSAL_LM", - use_dora=False, - use_rslora=False, - config_path=None, - ): - self.alpha_pattern = alpha_pattern or {} - self.auto_mapping = auto_mapping - self.base_model_name_or_path = base_model_name_or_path - self.bias = bias - self.fan_in_fan_out = fan_in_fan_out - self.inference_mode = inference_mode - self.init_lora_weights = init_lora_weights - self.layer_replication = layer_replication - self.layers_pattern = layers_pattern - self.layers_to_transform = layers_to_transform - self.loftq_config = loftq_config or {} - self.lora_alpha = lora_alpha - self.lora_dropout = lora_dropout - self.megatron_config = megatron_config - self.megatron_core = megatron_core - self.modules_to_save = modules_to_save - self.peft_type = peft_type - self.r = r - self.rank_pattern = rank_pattern or {} - self.revision = revision - self.target_modules = target_modules or ["q_proj", "v_proj"] - self.task_type = task_type - self.use_dora = use_dora - self.use_rslora = use_rslora - self.config_path = config_path - - @classmethod - def from_file(cls, filename): - with open(filename, "r") as f: - json_data = json.load(f) - return cls(**json_data, config_path=filename) - - # TODO: support fetching the model from the hub if it's not in the cache - @classmethod - def from_pretrained(cls, adapter_id, revision=None): - d = hub._get_cached_revision_directory(adapter_id, revision) - filename = os.path.join(d, "adapter_config.json") - return cls.from_file(filename)
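
Note: a minimal sketch (outside the diff, plain Python) of the two behavioural changes that make base-model-only batches work. The adapter ids, the empty adapter_weights dict, and the segment indices below are made-up placeholders; only the `max(..., default=0)` pattern from adapters/lora.py and the `index + 1` offset from server.py mirror the patch itself.

# A batch made up solely of base-model requests has no LoRA segments, so the
# rank generator below is empty. `default=0` keeps max() from raising
# ValueError on the empty sequence.
adapter_weights = {}          # hypothetical: no adapters loaded
segment_indices = [0, 0, 0]   # hypothetical: every segment targets the base model
max_rank = max(
    (
        adapter_weights[idx].lora_a_r
        for idx in segment_indices
        if idx in adapter_weights
    ),
    default=0,
)
assert max_rank == 0

# Adapters are now numbered from 1 (`adapter_index = index + 1` in server.py),
# presumably so that index 0 stays free to denote "no adapter", i.e. the base model.
lora_adapter_ids = ["adapter-a", "adapter-b"]   # hypothetical adapter ids
adapter_to_index = {a: i + 1 for i, a in enumerate(lora_adapter_ids)}
assert 0 not in adapter_to_index.values()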