feat: first draft load multiple lora

Author: drbh
Date: 2024-05-30 19:16:15 +00:00
Parent: 85dfc39222
Commit: db3d8e6518

14 changed files with 191 additions and 8 deletions

View File

@@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
+adapter_id: None,
})
.collect();

View File

@@ -107,6 +107,8 @@ message Request {
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
+/// LORA adapter id
+optional string adapter_id = 8;
}
message Batch {

View File

@@ -154,6 +154,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
+adapter_id: None,
});
n_tokens += max_input_length;

View File

@@ -290,6 +290,7 @@ impl State {
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
+adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@@ -429,6 +430,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
+adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@@ -298,6 +298,11 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,
+/// Lora adapter id
+#[serde(default)]
+#[schema(nullable = true, default = "null", example = "null")]
+pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
@@ -324,6 +329,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
+adapter_id: None,
}
}
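With `adapter_id` added to `GenerateParameters`, a client can select an adapter per request alongside the other sampling parameters. A minimal sketch (Python, assuming a server on localhost:3000 and the usual `/generate` route; the prompt and adapter value are placeholders):

    import requests

    resp = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "What is deep learning?",
            "parameters": {
                "max_new_tokens": 64,
                # New in this draft: pick a LoRA adapter for this request.
                # The value is an index into LORA_ADAPTERS as a string,
                # since flash_causal_lm later does int(r.adapter_id).
                "adapter_id": "0",
            },
        },
    )
    print(resp.json())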

View File

@@ -668,6 +668,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
+..Default::default()
},
})
.collect();
@@ -1092,6 +1093,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar: typed_grammar,
+..Default::default()
},
};

View File

@@ -202,6 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
+adapter_id,
..
} = request.parameters;
@@ -383,6 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
+adapter_id,
})
}
@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
+pub adapter_id: Option<String>,
}
#[derive(Error, Debug)]

View File

@@ -78,6 +78,14 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
+# TODO: determine if this api makes sense
+lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+# split on comma and strip whitespace
+lora_adapter_ids = (
+[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
+)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
@@ -92,6 +100,7 @@ def serve(
)
server.serve(
model_id,
+lora_adapter_ids,
revision,
sharded,
quantize,
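The cli.py hunk above reads the new `LORA_ADAPTERS` environment variable and splits it into a list before handing it to `server.serve`. A standalone sketch of that behaviour, with a hypothetical value an operator might export:

    import os

    # Hypothetical value set before launching the server.
    os.environ["LORA_ADAPTERS"] = "org/adapter-one, org/adapter-two"

    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
    # Same comma-split-and-strip logic as the hunk above.
    lora_adapter_ids = (
        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
    )
    print(lora_adapter_ids)  # ['org/adapter-one', 'org/adapter-two']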

View File

@@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional
+from typing import Optional, List
from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -253,6 +253,7 @@ for data in ModelType:
def get_model(
model_id: str,
+lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@@ -595,6 +596,7 @@ def get_model(
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
+lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))

View File

@@ -88,9 +88,11 @@ def load_attention(config, prefix, weights):
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
+index: int,
prefix: str,
config,
weights,
+all_adapter_weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -122,6 +124,29 @@ class FlashLlamaAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
+self.index = index
+self.adapter_weights = {}
+for adapter_id, adapter_weights in all_adapter_weights.items():
+filtered_keys = list(
+filter(
+lambda x: x.startswith(
+f"base_model.model.model.layers.{index}.self_attn"
+),
+adapter_weights.keys(),
+)
+)
+self.adapter_weights[adapter_id] = {
+key: torch.tensor(
+adapter_weights[key],
+device=weights.device,
+dtype=weights.dtype,
+).T
+for key in filtered_keys
+}
+self.index_to_key = {
+i: key for i, key in enumerate(self.adapter_weights.keys())
+}
self.o_proj = TensorParallelRowLinear.load(
config,
@@ -134,6 +159,23 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
+def get_adapter_weights(self, lora_index):
+adapter_id = self.index_to_key[lora_index]
+q_proj_lora_a = self.adapter_weights[adapter_id][
+f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
+]
+q_proj_lora_b = self.adapter_weights[adapter_id][
+f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
+]
+v_proj_lora_a = self.adapter_weights[adapter_id][
+f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
+]
+v_proj_lora_b = self.adapter_weights[adapter_id][
+f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
+]
+return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b
def forward(
self,
hidden_states,
@@ -145,6 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
slots,
input_lengths,
max_s,
+batch_lora_adapter_mask,
+lora_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@@ -157,6 +201,40 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
+self.get_adapter_weights(
+# TODO: dont just assume the first adapter
+lora_indices[0].item()
+)
+)
+query_adapted = torch.matmul(
+hidden_states,
+torch.matmul(
+q_proj_lora_a,
+q_proj_lora_b,
+),
+)
+value_adapted = torch.matmul(
+hidden_states,
+torch.matmul(
+v_proj_lora_a,
+v_proj_lora_b,
+),
+)
+batch_size = query.size(0)
+# TODO: improve this to avoid unnecessary work
+# mask across batch and within lora adapters
+query[batch_lora_adapter_mask] += query_adapted.view(
+batch_size, self.num_heads, self.head_size
+)[batch_lora_adapter_mask]
+kv[batch_lora_adapter_mask, 1] += value_adapted.view(
+batch_size, self.num_key_value_heads, self.head_size
+)[batch_lora_adapter_mask]
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
@@ -261,10 +339,14 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
-def __init__(self, prefix, config, weights):
+def __init__(self, index, prefix, config, weights, all_adapter_weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
-prefix=f"{prefix}.self_attn", config=config, weights=weights
+index=index,
+prefix=f"{prefix}.self_attn",
+config=config,
+weights=weights,
+all_adapter_weights=all_adapter_weights,
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -289,6 +371,8 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
+batch_lora_adapter_mask,
+lora_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -303,6 +387,8 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
+batch_lora_adapter_mask,
+lora_indices,
)
# faster post attention rms norm
@@ -316,7 +402,7 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
-def __init__(self, prefix, config, weights):
+def __init__(self, prefix, config, weights, all_adapter_weights):
super().__init__()
process_group = weights.process_group
@@ -325,6 +411,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
+index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
@@ -332,6 +419,7 @@ class FlashLlamaModel(torch.nn.Module):
),
config=config,
weights=weights,
+all_adapter_weights=all_adapter_weights,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -360,6 +448,8 @@ class FlashLlamaModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
+batch_lora_adapter_mask: Optional[List[str]],
+lora_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
@@ -382,6 +472,8 @@ class FlashLlamaModel(torch.nn.Module):
slots,
input_lengths,
max_s,
+batch_lora_adapter_mask,
+lora_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -390,7 +482,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
-def __init__(self, prefix, config, weights):
+def __init__(self, prefix, config, weights, all_adapter_weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@@ -399,7 +491,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
),
weights=weights,
)
-self.model = FlashLlamaModel(prefix, config, weights)
+self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
@@ -423,6 +515,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
+batch_lora_adapter_mask: Optional[List[str]] = None,
+lora_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@@ -436,6 +530,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
+batch_lora_adapter_mask=batch_lora_adapter_mask,
+lora_indices=lora_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
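The attention changes above apply the adapter as a low-rank update: `lora_A`/`lora_B` are transposed at load time, the delta is `hidden_states @ (A.T @ B.T)`, and it is added only to the rows selected by `batch_lora_adapter_mask` (note the draft applies the raw product, without the usual LoRA alpha/rank scaling). A standalone sketch of that arithmetic, with illustrative shapes rather than the model's real dimensions:

    import torch

    hidden_size, num_heads, head_size, rank, batch = 32, 4, 8, 2, 3
    hidden_states = torch.randn(batch, hidden_size)

    # LoRA factors as they sit in a PEFT checkpoint:
    # lora_A: (rank, hidden_size), lora_B: (out_features, rank)
    q_lora_A = torch.randn(rank, hidden_size)
    q_lora_B = torch.randn(num_heads * head_size, rank)

    # Stored transposed (.T) in the commit, so the delta is
    # (batch, hidden) @ (hidden, out) -> (batch, out)
    q_proj_lora_a = q_lora_A.T  # (hidden_size, rank)
    q_proj_lora_b = q_lora_B.T  # (rank, out_features)
    query_adapted = hidden_states @ (q_proj_lora_a @ q_proj_lora_b)

    # Only requests flagged in the mask receive the adapter contribution.
    batch_lora_adapter_mask = torch.tensor([True, False, True])
    query = torch.zeros(batch, num_heads, head_size)
    query[batch_lora_adapter_mask] += query_adapted.view(
        batch, num_heads, head_size
    )[batch_lora_adapter_mask]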

View File

@@ -811,6 +811,8 @@ class FlashCausalLM(Model):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
+batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
+lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
@@ -824,6 +826,8 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
+batch_lora_adapter_mask=batch_lora_adapter_mask,
+lora_indices=lora_indices,
)
torch.cuda.synchronize()
@@ -839,6 +843,8 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
+batch_lora_adapter_mask=batch_lora_adapter_mask,
+lora_indices=lora_indices,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
@@ -966,6 +972,10 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+batch_lora_adapter_mask = torch.zeros(
+seqlen, dtype=torch.bool, device=self.device
+)
+lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@@ -981,6 +991,8 @@ class FlashCausalLM(Model):
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
+batch_lora_adapter_mask=batch_lora_adapter_mask,
+lora_indices=lora_indices,
)
def forward(
@@ -1051,6 +1063,15 @@ class FlashCausalLM(Model):
else:
cuda_graph = None
+batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
+lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
+for i, r in enumerate(batch.requests):
+if r.adapter_id:
+lora_index = int(r.adapter_id)
+lora_indices[i] = lora_index
+batch_lora_adapter_mask[i] = True
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@@ -1063,6 +1084,8 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
+batch_lora_adapter_mask=batch_lora_adapter_mask,
+lora_indices=lora_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
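The per-batch bookkeeping in the forward hunk above turns each request's `adapter_id` (a stringified index in this draft) into a boolean mask plus an int32 index tensor that are passed into the model. A small CPU-only sketch of the same loop, using a hypothetical stand-in for `batch.requests`:

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class FakeRequest:  # stand-in for the gRPC Request with the new field
        adapter_id: Optional[str] = None

    requests = [FakeRequest("0"), FakeRequest(None), FakeRequest("1")]
    bs = len(requests)

    batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool)
    lora_indices = torch.zeros(bs, dtype=torch.int32)
    for i, r in enumerate(requests):
        if r.adapter_id:
            lora_indices[i] = int(r.adapter_id)  # draft treats adapter_id as an index
            batch_lora_adapter_mask[i] = True

    print(batch_lora_adapter_mask.tolist())  # [True, False, True]
    print(lora_indices.tolist())             # [0, 0, 1]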

View File

@@ -1,3 +1,4 @@
+import os
import torch
import torch.distributed
@@ -13,7 +14,9 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
+hub,
)
+from text_generation_server.utils.weights import load_adaptor_weights
tracer = trace.get_tracer(__name__)
@@ -29,6 +32,7 @@ class FlashLlama(FlashCausalLM):
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
+lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -71,7 +75,7 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id, revision)
prefix = ""
-model = FlashLlamaForCausalLM(prefix, config, weights)
+model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,

View File

@@ -192,6 +192,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
+lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@@ -203,6 +204,7 @@ def serve(
):
async def serve_inner(
model_id: str,
+lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
@@ -224,6 +226,7 @@ def serve(
try:
model = get_model(
model_id,
+lora_adapter_ids,
revision,
sharded,
quantize,
@@ -262,6 +265,13 @@ def serve(
set_model_id(model_id)
asyncio.run(
serve_inner(
-model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+model_id,
+lora_adapter_ids,
+revision,
+sharded,
+quantize,
+speculate,
+dtype,
+trust_remote_code,
)
)

View File

@@ -2,6 +2,7 @@ import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
+from safetensors.torch import load_file
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
@@ -9,6 +10,27 @@ import json
from text_generation_server.utils.log import log_once
+# TODO: improve how the weights are loaded
+def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
+adapter_weights = {}
+if local_path.exists() and local_path.is_dir():
+local_files = list(local_path.glob(f"*{extension}"))
+if not local_files:
+raise FileNotFoundError(
+f"No local weights found in {model_id} with extension {extension}"
+)
+for filename in local_files:
+adapter_weights.update(load_file(filename))
+# TODO: remove (no need to sort)
+# sorted on the the layer number (index 4 in the key)
+sorted_keys = sorted(
+adapter_weights.keys(),
+key=lambda x: int(x.split(".")[4]),
+)
+return (adapter_weights, sorted_keys)
class Weights:
def __init__(
self,
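A rough usage sketch of the new `load_adaptor_weights` helper, assuming an adapter has already been downloaded to a local directory of *.safetensors files (the path and repo name below are placeholders):

    from pathlib import Path
    from text_generation_server.utils.weights import load_adaptor_weights

    # Placeholder path to an already-downloaded PEFT adapter directory.
    local_path = Path("/data/adapters/my-llama-lora")

    adapter_weights, sorted_keys = load_adaptor_weights(
        "my-org/my-llama-lora",  # model_id is only used in the error message
        local_path,
    )

    # Keys look like:
    # base_model.model.model.layers.<layer>.self_attn.q_proj.lora_A.weight
    # which is the prefix FlashLlamaAttention filters on per layer index.
    print(len(adapter_weights), sorted_keys[:2])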