feat: first draft load multiple lora

This commit is contained in:
drbh 2024-05-30 19:16:15 +00:00
parent 85dfc39222
commit db3d8e6518
14 changed files with 191 additions and 8 deletions

View File

@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
})
.collect();

View File

@ -107,6 +107,8 @@ message Request {
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// LORA adapter id
optional string adapter_id = 8;
}
message Batch {

View File

@ -154,6 +154,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;

View File

@ -290,6 +290,7 @@ impl State {
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -429,6 +430,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -298,6 +298,11 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,
/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
@ -324,6 +329,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_id: None,
}
}

View File

@ -668,6 +668,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
},
})
.collect();
@ -1092,6 +1093,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar: typed_grammar,
..Default::default()
},
};

View File

@ -202,6 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_id,
..
} = request.parameters;
@ -383,6 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
adapter_id,
})
}
@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_id: Option<String>,
}
#[derive(Error, Debug)]

View File

@ -78,6 +78,14 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
# TODO: determine if this api makes sense
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
# split on comma and strip whitespace
lora_adapter_ids = (
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
@ -92,6 +100,7 @@ def serve(
)
server.serve(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,

View File

@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from typing import Optional, List
from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
@ -253,6 +253,7 @@ for data in ModelType:
def get_model(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@ -595,6 +596,7 @@ def get_model(
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))

View File

@ -88,9 +88,11 @@ def load_attention(config, prefix, weights):
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
index: int,
prefix: str,
config,
weights,
all_adapter_weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -122,6 +124,29 @@ class FlashLlamaAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.index = index
self.adapter_weights = {}
for adapter_id, adapter_weights in all_adapter_weights.items():
filtered_keys = list(
filter(
lambda x: x.startswith(
f"base_model.model.model.layers.{index}.self_attn"
),
adapter_weights.keys(),
)
)
self.adapter_weights[adapter_id] = {
key: torch.tensor(
adapter_weights[key],
device=weights.device,
dtype=weights.dtype,
).T
for key in filtered_keys
}
self.index_to_key = {
i: key for i, key in enumerate(self.adapter_weights.keys())
}
self.o_proj = TensorParallelRowLinear.load(
config,
@ -134,6 +159,23 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def get_adapter_weights(self, lora_index):
adapter_id = self.index_to_key[lora_index]
q_proj_lora_a = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
]
q_proj_lora_b = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
]
v_proj_lora_a = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
]
v_proj_lora_b = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
]
return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b
def forward(
self,
hidden_states,
@ -145,6 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -157,6 +201,40 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
self.get_adapter_weights(
# TODO: dont just assume the first adapter
lora_indices[0].item()
)
)
query_adapted = torch.matmul(
hidden_states,
torch.matmul(
q_proj_lora_a,
q_proj_lora_b,
),
)
value_adapted = torch.matmul(
hidden_states,
torch.matmul(
v_proj_lora_a,
v_proj_lora_b,
),
)
batch_size = query.size(0)
# TODO: improve this to avoid unnecessary work
# mask across batch and within lora adapters
query[batch_lora_adapter_mask] += query_adapted.view(
batch_size, self.num_heads, self.head_size
)[batch_lora_adapter_mask]
kv[batch_lora_adapter_mask, 1] += value_adapted.view(
batch_size, self.num_key_value_heads, self.head_size
)[batch_lora_adapter_mask]
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
@ -261,10 +339,14 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, index, prefix, config, weights, all_adapter_weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
all_adapter_weights=all_adapter_weights,
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -289,6 +371,8 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -303,6 +387,8 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
)
# faster post attention rms norm
@ -316,7 +402,7 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, all_adapter_weights):
super().__init__()
process_group = weights.process_group
@ -325,6 +411,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
@ -332,6 +419,7 @@ class FlashLlamaModel(torch.nn.Module):
),
config=config,
weights=weights,
all_adapter_weights=all_adapter_weights,
)
for layer_id in range(config.num_hidden_layers)
]
@ -360,6 +448,8 @@ class FlashLlamaModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
batch_lora_adapter_mask: Optional[List[str]],
lora_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -382,6 +472,8 @@ class FlashLlamaModel(torch.nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -390,7 +482,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, all_adapter_weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -399,7 +491,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
),
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights)
self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
@ -423,6 +515,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
batch_lora_adapter_mask: Optional[List[str]] = None,
lora_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -436,6 +530,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -811,6 +811,8 @@ class FlashCausalLM(Model):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
@ -824,6 +826,8 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
torch.cuda.synchronize()
@ -839,6 +843,8 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
@ -966,6 +972,10 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
batch_lora_adapter_mask = torch.zeros(
seqlen, dtype=torch.bool, device=self.device
)
lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@ -981,6 +991,8 @@ class FlashCausalLM(Model):
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
def forward(
@ -1051,6 +1063,15 @@ class FlashCausalLM(Model):
else:
cuda_graph = None
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
for i, r in enumerate(batch.requests):
if r.adapter_id:
lora_index = int(r.adapter_id)
lora_indices[i] = lora_index
batch_lora_adapter_mask[i] = True
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@ -1063,6 +1084,8 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None

View File

@ -1,3 +1,4 @@
import os
import torch
import torch.distributed
@ -13,7 +14,9 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
from text_generation_server.utils.weights import load_adaptor_weights
tracer = trace.get_tracer(__name__)
@ -29,6 +32,7 @@ class FlashLlama(FlashCausalLM):
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -71,7 +75,7 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,

View File

@ -192,6 +192,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@ -203,6 +204,7 @@ def serve(
):
async def serve_inner(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
@ -224,6 +226,7 @@ def serve(
try:
model = get_model(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
@ -262,6 +265,13 @@ def serve(
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
)
)

View File

@ -2,6 +2,7 @@ import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
from safetensors.torch import load_file
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
@ -9,6 +10,27 @@ import json
from text_generation_server.utils.log import log_once
# TODO: improve how the weights are loaded
def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
adapter_weights = {}
if local_path.exists() and local_path.is_dir():
local_files = list(local_path.glob(f"*{extension}"))
if not local_files:
raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}"
)
for filename in local_files:
adapter_weights.update(load_file(filename))
# TODO: remove (no need to sort)
# sorted on the the layer number (index 4 in the key)
sorted_keys = sorted(
adapter_weights.keys(),
key=lambda x: int(x.split(".")[4]),
)
return (adapter_weights, sorted_keys)
class Weights:
def __init__(
self,