From db3d8e6518733fe3ea643cadc507b29d0612d7b2 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 30 May 2024 19:16:15 +0000
Subject: [PATCH] feat: first draft load multiple lora

---
 benchmark/src/generation.rs                    |   1 +
 proto/generate.proto                           |   2 +
 router/client/src/v2/client.rs                 |   1 +
 router/src/infer/v2/queue.rs                   |   2 +
 router/src/lib.rs                              |   6 +
 router/src/server.rs                           |   2 +
 router/src/validation.rs                       |   3 +
 server/text_generation_server/cli.py           |   9 ++
 .../text_generation_server/models/__init__.py  |   4 +-
 .../custom_modeling/flash_llama_modeling.py    | 106 +++++++++++++++++-
 .../models/flash_causal_lm.py                  |  23 ++++
 .../models/flash_llama.py                      |   6 +-
 server/text_generation_server/server.py        |  12 +-
 .../text_generation_server/utils/weights.py    |  22 ++++
 14 files changed, 191 insertions(+), 8 deletions(-)

diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index b82d23ba4..5e739703f 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -157,6 +157,7 @@ async fn prefill(
             top_n_tokens: top_n_tokens.unwrap_or(0),
             blocks: vec![],
             slots: vec![],
+            adapter_id: None,
         })
         .collect();

diff --git a/proto/generate.proto b/proto/generate.proto
index 6351e37f2..cffaa719b 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -107,6 +107,8 @@ message Request {
   bool prefill_logprobs = 6;
   /// Return most likely n tokens
   uint32 top_n_tokens = 7;
+  /// LORA adapter id
+  optional string adapter_id = 8;
 }

 message Batch {

diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs
index 9a2e6ac79..3e5d9d3ba 100644
--- a/router/client/src/v2/client.rs
+++ b/router/client/src/v2/client.rs
@@ -154,6 +154,7 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
+                adapter_id: None,
             });
             n_tokens += max_input_length;

diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs
index 3725c03e6..9265b79a0 100644
--- a/router/src/infer/v2/queue.rs
+++ b/router/src/infer/v2/queue.rs
@@ -290,6 +290,7 @@ impl State {
                     entry.request.stopping_parameters.clone(),
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
+                adapter_id: entry.request.adapter_id.clone(),
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -429,6 +430,7 @@ mod tests {
                     stop_sequences: vec![],
                 },
                 top_n_tokens: 0,
+                adapter_id: None,
             },
             response_tx,
             span: info_span!("entry"),

diff --git a/router/src/lib.rs b/router/src/lib.rs
index b6902c497..08d57873e 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -298,6 +298,11 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub grammar: Option<GrammarType>,
+
+    /// Lora adapter id
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub adapter_id: Option<String>,
 }

 fn default_max_new_tokens() -> Option<u32> {
@@ -324,6 +329,7 @@ fn default_parameters() -> GenerateParameters {
         seed: None,
         top_n_tokens: None,
         grammar: None,
+        adapter_id: None,
     }
 }

diff --git a/router/src/server.rs b/router/src/server.rs
index 30479b0e4..fee8ebdf7 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -668,6 +668,7 @@ async fn completions(
                 seed,
                 top_n_tokens: None,
                 grammar: None,
+                ..Default::default()
             },
         })
         .collect();
@@ -1092,6 +1093,7 @@ async fn chat_completions(
             seed,
             top_n_tokens: req.top_logprobs,
             grammar: typed_grammar,
+            ..Default::default()
         },
     };
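With the router now carrying an optional adapter_id through GenerateParameters and the proto Request, a caller can name the adapter it wants per request. Below is a minimal client-side sketch of how that parameter might be passed once this lands; the host, port, and payload values are assumptions, and note that in this first draft the server parses the value as an integer index into the adapters loaded at startup (see flash_causal_lm.py further down).

    # Hypothetical client sketch; endpoint, port, and values are assumptions.
    import requests

    resp = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "What is Deep Learning?",
            "parameters": {
                "max_new_tokens": 64,
                # New optional field added by this patch; omitting it keeps the base model.
                # The draft treats the id as an integer index into the loaded adapters.
                "adapter_id": "0",
            },
        },
    )
    print(resp.json())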
diff --git a/router/src/validation.rs b/router/src/validation.rs
index bb9ad3184..e2bf5a5d6 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -202,6 +202,7 @@ impl Validation {
             decoder_input_details,
             top_n_tokens,
             grammar,
+            adapter_id,
             ..
         } = request.parameters;
@@ -383,6 +384,7 @@ impl Validation {
             parameters,
             stopping_parameters,
             top_n_tokens,
+            adapter_id,
         })
     }
@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,
     pub top_n_tokens: u32,
+    pub adapter_id: Option<String>,
 }

 #[derive(Error, Debug)]

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 430323bcd..b18deabc3 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -78,6 +78,14 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

+    # TODO: determine if this api makes sense
+    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+
+    # split on comma and strip whitespace
+    lora_adapter_ids = (
+        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
+    )
+
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
@@ -92,6 +100,7 @@ def serve(
     )
     server.serve(
         model_id,
+        lora_adapter_ids,
         revision,
         sharded,
         quantize,

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index a61cb83b4..34082ac25 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional
+from typing import Optional, List
 from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -253,6 +253,7 @@ for data in ModelType:

 def get_model(
     model_id: str,
+    lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -595,6 +596,7 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
         )
     elif sharded:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
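The adapters themselves are declared once at startup through the LORA_ADAPTERS environment variable, which cli.py splits on commas and forwards to get_model. A small sketch of that convention, with made-up adapter names, just to make the parsing above concrete:

    # Illustration of the LORA_ADAPTERS convention used in cli.py; adapter names are made up.
    import os

    os.environ["LORA_ADAPTERS"] = "org/math-lora, org/code-lora"

    raw = os.getenv("LORA_ADAPTERS", None)
    lora_adapter_ids = [x.strip() for x in raw.split(",")] if raw else []
    print(lora_adapter_ids)  # ['org/math-lora', 'org/code-lora']
    # get_model(...) then receives this list as its new lora_adapter_ids argument.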
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 0d06d1048..27703e14f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -88,9 +88,11 @@ def load_attention(config, prefix, weights):
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
+        all_adapter_weights,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -122,6 +124,29 @@ class FlashLlamaAttention(torch.nn.Module):
         )

         self.query_key_value = load_attention(config, prefix, weights)
+        self.index = index
+        self.adapter_weights = {}
+        for adapter_id, adapter_weights in all_adapter_weights.items():
+            filtered_keys = list(
+                filter(
+                    lambda x: x.startswith(
+                        f"base_model.model.model.layers.{index}.self_attn"
+                    ),
+                    adapter_weights.keys(),
+                )
+            )
+            self.adapter_weights[adapter_id] = {
+                key: torch.tensor(
+                    adapter_weights[key],
+                    device=weights.device,
+                    dtype=weights.dtype,
+                ).T
+                for key in filtered_keys
+            }
+
+        self.index_to_key = {
+            i: key for i, key in enumerate(self.adapter_weights.keys())
+        }

         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -134,6 +159,23 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

+    def get_adapter_weights(self, lora_index):
+        adapter_id = self.index_to_key[lora_index]
+        q_proj_lora_a = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
+        ]
+        q_proj_lora_b = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
+        ]
+
+        v_proj_lora_a = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
+        ]
+        v_proj_lora_b = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
+        ]
+        return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b
+
     def forward(
         self,
         hidden_states,
@@ -145,6 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        batch_lora_adapter_mask,
+        lora_indices,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -157,6 +201,40 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

+        q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
+            self.get_adapter_weights(
+                # TODO: don't just assume the first adapter
+                lora_indices[0].item()
+            )
+        )
+
+        query_adapted = torch.matmul(
+            hidden_states,
+            torch.matmul(
+                q_proj_lora_a,
+                q_proj_lora_b,
+            ),
+        )
+
+        value_adapted = torch.matmul(
+            hidden_states,
+            torch.matmul(
+                v_proj_lora_a,
+                v_proj_lora_b,
+            ),
+        )
+
+        batch_size = query.size(0)
+
+        # TODO: improve this to avoid unnecessary work
+        # mask across batch and within lora adapters
+        query[batch_lora_adapter_mask] += query_adapted.view(
+            batch_size, self.num_heads, self.head_size
+        )[batch_lora_adapter_mask]
+        kv[batch_lora_adapter_mask, 1] += value_adapted.view(
+            batch_size, self.num_key_value_heads, self.head_size
+        )[batch_lora_adapter_mask]
+
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
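The core of the change is the low-rank update in FlashLlamaAttention.forward: for the masked rows of the batch, the query and value projections are shifted by x @ A^T @ B^T, the usual LoRA delta. Here is a standalone sketch of that arithmetic, with shapes following the PEFT convention of lora_A as [r, in] and lora_B as [out, r]; the alpha/r scaling that full LoRA normally applies is left out, as it is in this draft. Dimensions below are illustrative.

    # Minimal sketch of the LoRA delta applied above; not the server's actual kernel.
    import torch

    hidden, rank, out = 4096, 16, 4096
    x = torch.randn(8, hidden)              # token activations, one row per token
    lora_A = torch.randn(rank, hidden)      # lora_A.weight as stored in the adapter
    lora_B = torch.randn(out, rank)         # lora_B.weight as stored in the adapter

    # The draft pre-transposes both factors at load time, so the update is
    # x @ (A^T @ B^T), which equals x @ (B @ A)^T, i.e. delta_W = B @ A.
    delta = x @ (lora_A.T @ lora_B.T)       # shape [8, out]

    base = x @ torch.randn(out, hidden).T   # stand-in for the frozen q_proj/v_proj output
    adapted = base + delta                  # what gets added into query/value for masked rows
    print(adapted.shape)                    # torch.Size([8, 4096])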
@@ -261,10 +339,14 @@ class LlamaMLP(nn.Module):


 class FlashLlamaLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights, all_adapter_weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            index=index,
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            all_adapter_weights=all_adapter_weights,
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -289,6 +371,8 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        batch_lora_adapter_mask,
+        lora_indices,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -303,6 +387,8 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            batch_lora_adapter_mask,
+            lora_indices,
         )

         # faster post attention rms norm
@@ -316,7 +402,7 @@ class FlashLlamaLayer(nn.Module):


 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, all_adapter_weights):
         super().__init__()

         process_group = weights.process_group
@@ -325,6 +411,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
+                    index=layer_id,
                     prefix=(
                         f"model.layers.{layer_id}"
                         if not prefix
                         else f"{prefix}.model.layers.{layer_id}"
                     ),
                     config=config,
                     weights=weights,
+                    all_adapter_weights=all_adapter_weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -360,6 +448,8 @@ class FlashLlamaModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        batch_lora_adapter_mask: Optional[List[str]],
+        lora_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -382,6 +472,8 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                batch_lora_adapter_mask,
+                lora_indices,
             )

         hidden_states, _ = self.norm(hidden_states, residual)
@@ -390,7 +482,7 @@ class FlashLlamaModel(torch.nn.Module):


 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, all_adapter_weights):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(
@@ -399,7 +491,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.model = FlashLlamaModel(prefix, config, weights)
+        self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -423,6 +515,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        batch_lora_adapter_mask: Optional[List[str]] = None,
+        lora_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -436,6 +530,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
+            batch_lora_adapter_mask=batch_lora_adapter_mask,
+            lora_indices=lora_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d16d37106..285271133 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -811,6 +811,8 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

+        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
+        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
         torch.cuda.synchronize()
         # Run once outside to warmup
         self.model.forward(
@@ -824,6 +826,8 @@ class FlashCausalLM(Model):
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
+            batch_lora_adapter_mask=batch_lora_adapter_mask,
+            lora_indices=lora_indices,
         )
         torch.cuda.synchronize()
@@ -839,6 +843,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
+                batch_lora_adapter_mask=batch_lora_adapter_mask,
+                lora_indices=lora_indices,
             )
         self.cuda_graphs[bs]["logits"] = logits
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
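At batch level the draft drives the per-layer update with two tensors: a boolean mask marking which rows of the batch want an adapter, and an int32 tensor of adapter indices, both allocated at the fixed batch size so they can also be baked into the CUDA graph warmup above. A small sketch of how those tensors are built from per-request adapter ids, mirroring the loop added to FlashCausalLM.forward further down; the request objects here are stand-ins, and, as in the draft, adapter_id is an integer index encoded as a string.

    # Sketch of building the mask/index tensors consumed by the model's forward pass.
    # `requests` is a stand-in for batch.requests.
    import torch

    class FakeRequest:
        def __init__(self, adapter_id=None):
            self.adapter_id = adapter_id

    requests = [FakeRequest("0"), FakeRequest(None), FakeRequest("1")]
    bs = len(requests)

    batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool)
    lora_indices = torch.zeros(bs, dtype=torch.int32)

    for i, r in enumerate(requests):
        if r.adapter_id:
            lora_indices[i] = int(r.adapter_id)
            batch_lora_adapter_mask[i] = True

    print(batch_lora_adapter_mask)  # tensor([ True, False,  True])
    print(lora_indices)             # tensor([0, 0, 1], dtype=torch.int32)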
@@ -966,6 +972,10 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        batch_lora_adapter_mask = torch.zeros(
+            seqlen, dtype=torch.bool, device=self.device
+        )
+        lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -981,6 +991,8 @@ class FlashCausalLM(Model):
             max_s=seqlen,
             lm_head_indices=None,
             prefill_cache_indices=None,
+            batch_lora_adapter_mask=batch_lora_adapter_mask,
+            lora_indices=lora_indices,
         )

     def forward(
@@ -1051,6 +1063,15 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None

+        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
+        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
+
+        for i, r in enumerate(batch.requests):
+            if r.adapter_id:
+                lora_index = int(r.adapter_id)
+                lora_indices[i] = lora_index
+                batch_lora_adapter_mask[i] = True
+
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
@@ -1063,6 +1084,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
+                batch_lora_adapter_mask=batch_lora_adapter_mask,
+                lora_indices=lora_indices,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None

diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index c5cbd2b83..78b352760 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -1,3 +1,4 @@
+import os
 import torch
 import torch.distributed
@@ -13,7 +14,9 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    hub,
 )
+from text_generation_server.utils.weights import load_adaptor_weights

 tracer = trace.get_tracer(__name__)
@@ -29,6 +32,7 @@ class FlashLlama(FlashCausalLM):
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -71,7 +75,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         prefix = ""
-        model = FlashLlamaForCausalLM(prefix, config, weights)
+        model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
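FlashLlama now hands the raw adapter state dicts to the model, and each attention layer pulls out its own q_proj/v_proj factors by key prefix. The keys follow the PEFT naming scheme, so a sketch of what a single adapter contributes for layer 0 looks like the following; the tensors are dummies and the dimensions are illustrative, only the key layout matters here.

    # Sketch of the per-layer key layout the attention module filters on.
    import torch

    adapter_state = {
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(16, 4096),
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(4096, 16),
        "base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(16, 4096),
        "base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(1024, 16),
    }

    index = 0
    layer_keys = [
        k for k in adapter_state
        if k.startswith(f"base_model.model.model.layers.{index}.self_attn")
    ]
    print(layer_keys)  # the four q_proj/v_proj LoRA factors owned by layer 0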
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 569b6925a..9a5e92263 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -192,6 +192,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

 def serve(
     model_id: str,
+    lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -203,6 +204,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
+        lora_adapter_ids: Optional[List[str]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -224,6 +226,7 @@ def serve(
         try:
             model = get_model(
                 model_id,
+                lora_adapter_ids,
                 revision,
                 sharded,
                 quantize,
@@ -262,6 +265,13 @@ def serve(
     set_model_id(model_id)
     asyncio.run(
         serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+            model_id,
+            lora_adapter_ids,
+            revision,
+            sharded,
+            quantize,
+            speculate,
+            dtype,
+            trust_remote_code,
         )
     )

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 4d5fcb254..ab18e0c74 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -2,6 +2,7 @@ import os
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
+from safetensors.torch import load_file
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
@@ -9,6 +10,27 @@ import json

 from text_generation_server.utils.log import log_once

+
+# TODO: improve how the weights are loaded
+def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
+    adapter_weights = {}
+    if local_path.exists() and local_path.is_dir():
+        local_files = list(local_path.glob(f"*{extension}"))
+        if not local_files:
+            raise FileNotFoundError(
+                f"No local weights found in {model_id} with extension {extension}"
+            )
+        for filename in local_files:
+            adapter_weights.update(load_file(filename))
+
+    # TODO: remove (no need to sort)
+    # sorted on the layer number (index 4 in the key)
+    sorted_keys = sorted(
+        adapter_weights.keys(),
+        key=lambda x: int(x.split(".")[4]),
+    )
+    return (adapter_weights, sorted_keys)
+
+
 class Weights:
     def __init__(
         self,
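For completeness, a hedged sketch of how the new load_adaptor_weights helper could be exercised on its own, assuming an adapter has already been downloaded to a local directory of safetensors files; the directory path and adapter name below are hypothetical, and the helper is a server-internal utility rather than a public API.

    # Hypothetical usage of the helper added above; the local path is an assumption.
    from pathlib import Path

    from text_generation_server.utils.weights import load_adaptor_weights

    adapter_dir = Path("/data/adapters/my-lora")  # assumed to contain *.safetensors files
    adapter_weights, sorted_keys = load_adaptor_weights(
        "my-org/my-lora-adapter",  # model_id is only used in the error message
        adapter_dir,
    )
    # adapter_weights maps PEFT-style keys to tensors; sorted_keys orders them by layer number.
    print(len(adapter_weights), sorted_keys[:2])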