Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)

feat: first draft load multiple lora

commit db3d8e6518 (parent 85dfc39222)
@@ -157,6 +157,7 @@ async fn prefill(
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
            adapter_id: None,
        })
        .collect();
@@ -107,6 +107,8 @@ message Request {
    bool prefill_logprobs = 6;
    /// Return most likely n tokens
    uint32 top_n_tokens = 7;
    /// LORA adapter id
    optional string adapter_id = 8;
}

message Batch {
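For orientation (not part of the diff), a minimal sketch of a prefill Request carrying the new field from Python, assuming the generated gRPC stub is importable as text_generation_server.pb.generate_pb2; field names come from the proto hunk above, and the value is a stringified adapter index because the server-side draft later calls int(adapter_id):

from text_generation_server.pb import generate_pb2  # assumed stub path

request = generate_pb2.Request(
    id=0,
    inputs="What is Deep Learning?",
    prefill_logprobs=True,
    top_n_tokens=5,
    # Optional field: leaving it unset means "use the base weights only".
    adapter_id="0",
)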
@@ -154,6 +154,7 @@ impl Client {
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
                adapter_id: None,
            });
            n_tokens += max_input_length;
@@ -290,6 +290,7 @@ impl State {
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                adapter_id: entry.request.adapter_id.clone(),
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -429,6 +430,7 @@ mod tests {
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
@@ -298,6 +298,11 @@ pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub grammar: Option<GrammarType>,

    /// Lora adapter id
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub adapter_id: Option<String>,
}

fn default_max_new_tokens() -> Option<u32> {
@@ -324,6 +329,7 @@ fn default_parameters() -> GenerateParameters {
        seed: None,
        top_n_tokens: None,
        grammar: None,
        adapter_id: None,
    }
}
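For illustration (not part of the diff), a hedged example of where the new parameter would sit in a /generate payload once the router exposes it; host, prompt, and adapter value are placeholders, and the value is a stringified adapter index given the int(adapter_id) cast later in this draft:

import requests

response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "max_new_tokens": 32,
            "adapter_id": "0",  # hypothetical adapter index as a string
        },
    },
)
print(response.json())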
@@ -668,6 +668,7 @@ async fn completions(
                seed,
                top_n_tokens: None,
                grammar: None,
                ..Default::default()
            },
        })
        .collect();
@@ -1092,6 +1093,7 @@ async fn chat_completions(
            seed,
            top_n_tokens: req.top_logprobs,
            grammar: typed_grammar,
            ..Default::default()
        },
    };
@@ -202,6 +202,7 @@ impl Validation {
            decoder_input_details,
            top_n_tokens,
            grammar,
            adapter_id,
            ..
        } = request.parameters;
@@ -383,6 +384,7 @@ impl Validation {
            parameters,
            stopping_parameters,
            top_n_tokens,
            adapter_id,
        })
    }
@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
    pub parameters: ValidParameters,
    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
    pub adapter_id: Option<String>,
}

#[derive(Error, Debug)]
@@ -78,6 +78,14 @@ def serve(
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    # TODO: determine if this api makes sense
    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)

    # split on comma and strip whitespace
    lora_adapter_ids = (
        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
    )

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
@@ -92,6 +100,7 @@ def serve(
    )
    server.serve(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
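A standalone sketch of the LORA_ADAPTERS parsing introduced above: a comma-separated environment variable becomes a list of ids, and an unset or empty value yields an empty list (the example ids are invented):

import os

os.environ["LORA_ADAPTERS"] = "org/adapter-one, org/adapter-two"  # example value

raw = os.getenv("LORA_ADAPTERS", None)
lora_adapter_ids = [x.strip() for x in raw.split(",")] if raw else []
print(lora_adapter_ids)  # ['org/adapter-one', 'org/adapter-two']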
@@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from typing import Optional, List
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -253,6 +253,7 @@ for data in ModelType:

def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
@@ -595,6 +596,7 @@ def get_model(
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    elif sharded:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -88,9 +88,11 @@ def load_attention(config, prefix, weights):
class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
        index: int,
        prefix: str,
        config,
        weights,
        all_adapter_weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
@@ -122,6 +124,29 @@ class FlashLlamaAttention(torch.nn.Module):
        )

        self.query_key_value = load_attention(config, prefix, weights)
        self.index = index
        self.adapter_weights = {}
        for adapter_id, adapter_weights in all_adapter_weights.items():
            filtered_keys = list(
                filter(
                    lambda x: x.startswith(
                        f"base_model.model.model.layers.{index}.self_attn"
                    ),
                    adapter_weights.keys(),
                )
            )
            self.adapter_weights[adapter_id] = {
                key: torch.tensor(
                    adapter_weights[key],
                    device=weights.device,
                    dtype=weights.dtype,
                ).T
                for key in filtered_keys
            }

        self.index_to_key = {
            i: key for i, key in enumerate(self.adapter_weights.keys())
        }

        self.o_proj = TensorParallelRowLinear.load(
            config,
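The startswith filter above assumes PEFT-style key names of the form base_model.model.model.layers.<n>.self_attn.<proj>.lora_A/lora_B.weight. A small sketch with invented keys showing which entries survive for one layer:

# Invented keys for illustration; tensor values elided.
adapter_weights = {
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": "...",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": "...",
    "base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": "...",
    "base_model.model.model.layers.1.self_attn.q_proj.lora_A.weight": "...",
}

index = 0
layer_keys = [
    k
    for k in adapter_weights
    if k.startswith(f"base_model.model.model.layers.{index}.self_attn")
]
# Only the three layer-0 self_attn keys remain; layer 1 is filtered out.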
@@ -134,6 +159,23 @@ class FlashLlamaAttention(torch.nn.Module):
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def get_adapter_weights(self, lora_index):
        adapter_id = self.index_to_key[lora_index]
        q_proj_lora_a = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
        ]
        q_proj_lora_b = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
        ]

        v_proj_lora_a = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
        ]
        v_proj_lora_b = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
        ]
        return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b

    def forward(
        self,
        hidden_states,
@@ -145,6 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
        slots,
        input_lengths,
        max_s,
        batch_lora_adapter_mask,
        lora_indices,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
@@ -157,6 +201,40 @@ class FlashLlamaAttention(torch.nn.Module):
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
            self.get_adapter_weights(
                # TODO: dont just assume the first adapter
                lora_indices[0].item()
            )
        )

        query_adapted = torch.matmul(
            hidden_states,
            torch.matmul(
                q_proj_lora_a,
                q_proj_lora_b,
            ),
        )

        value_adapted = torch.matmul(
            hidden_states,
            torch.matmul(
                v_proj_lora_a,
                v_proj_lora_b,
            ),
        )

        batch_size = query.size(0)

        # TODO: improve this to avoid unnecessary work
        # mask across batch and within lora adapters
        query[batch_lora_adapter_mask] += query_adapted.view(
            batch_size, self.num_heads, self.head_size
        )[batch_lora_adapter_mask]
        kv[batch_lora_adapter_mask, 1] += value_adapted.view(
            batch_size, self.num_key_value_heads, self.head_size
        )[batch_lora_adapter_mask]

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
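Shape sketch of the low-rank update applied above (sizes invented). Because the adapter tensors were transposed at load time, lora_a is (hidden, r) and lora_b is (r, out), so hidden_states @ (lora_a @ lora_b) is the standard LoRA delta x A^T B^T; this draft applies it only to q_proj and v_proj and does not include the usual lora_alpha / r scaling:

import torch

tokens, hidden, rank = 10, 4096, 8
hidden_states = torch.randn(tokens, hidden)
lora_a = torch.randn(hidden, rank)  # e.g. q_proj lora_A after .T
lora_b = torch.randn(rank, hidden)  # e.g. q_proj lora_B after .T

delta = torch.matmul(hidden_states, torch.matmul(lora_a, lora_b))
print(delta.shape)  # torch.Size([10, 4096]); added only where the batch mask is True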
@@ -261,10 +339,14 @@ class LlamaMLP(nn.Module):


class FlashLlamaLayer(nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, index, prefix, config, weights, all_adapter_weights):
        super().__init__()
        self.self_attn = FlashLlamaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
            index=index,
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            all_adapter_weights=all_adapter_weights,
        )
        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -289,6 +371,8 @@ class FlashLlamaLayer(nn.Module):
        slots,
        input_lengths,
        max_s,
        batch_lora_adapter_mask,
        lora_indices,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -303,6 +387,8 @@ class FlashLlamaLayer(nn.Module):
            slots,
            input_lengths,
            max_s,
            batch_lora_adapter_mask,
            lora_indices,
        )

        # faster post attention rms norm
@@ -316,7 +402,7 @@ class FlashLlamaLayer(nn.Module):


class FlashLlamaModel(torch.nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, prefix, config, weights, all_adapter_weights):
        super().__init__()

        process_group = weights.process_group
@@ -325,6 +411,7 @@ class FlashLlamaModel(torch.nn.Module):
        self.layers = nn.ModuleList(
            [
                FlashLlamaLayer(
                    index=layer_id,
                    prefix=(
                        f"model.layers.{layer_id}"
                        if not prefix
@@ -332,6 +419,7 @@ class FlashLlamaModel(torch.nn.Module):
                    ),
                    config=config,
                    weights=weights,
                    all_adapter_weights=all_adapter_weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
@@ -360,6 +448,8 @@ class FlashLlamaModel(torch.nn.Module):
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        batch_lora_adapter_mask: Optional[List[str]],
        lora_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

@@ -382,6 +472,8 @@ class FlashLlamaModel(torch.nn.Module):
                slots,
                input_lengths,
                max_s,
                batch_lora_adapter_mask,
                lora_indices,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
@@ -390,7 +482,7 @@ class FlashLlamaModel(torch.nn.Module):


class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, prefix, config, weights, all_adapter_weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
@@ -399,7 +491,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            ),
            weights=weights,
        )
        self.model = FlashLlamaModel(prefix, config, weights)
        self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
@@ -423,6 +515,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        batch_lora_adapter_mask: Optional[List[str]] = None,
        lora_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
@@ -436,6 +530,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
            batch_lora_adapter_mask=batch_lora_adapter_mask,
            lora_indices=lora_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
@@ -811,6 +811,8 @@ class FlashCausalLM(Model):
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
@@ -824,6 +826,8 @@ class FlashCausalLM(Model):
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
            batch_lora_adapter_mask=batch_lora_adapter_mask,
            lora_indices=lora_indices,
        )
        torch.cuda.synchronize()
@@ -839,6 +843,8 @@ class FlashCausalLM(Model):
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
                batch_lora_adapter_mask=batch_lora_adapter_mask,
                lora_indices=lora_indices,
            )
        self.cuda_graphs[bs]["logits"] = logits
        self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
@@ -966,6 +972,10 @@ class FlashCausalLM(Model):

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        batch_lora_adapter_mask = torch.zeros(
            seqlen, dtype=torch.bool, device=self.device
        )
        lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
@@ -981,6 +991,8 @@ class FlashCausalLM(Model):
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
            batch_lora_adapter_mask=batch_lora_adapter_mask,
            lora_indices=lora_indices,
        )

    def forward(
@@ -1051,6 +1063,15 @@ class FlashCausalLM(Model):
        else:
            cuda_graph = None

        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)

        for i, r in enumerate(batch.requests):
            if r.adapter_id:
                lora_index = int(r.adapter_id)
                lora_indices[i] = lora_index
                batch_lora_adapter_mask[i] = True

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
@@ -1063,6 +1084,8 @@ class FlashCausalLM(Model):
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                batch_lora_adapter_mask=batch_lora_adapter_mask,
                lora_indices=lora_indices,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
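A standalone sketch of how the per-request routing tensors above are built: rows whose request carries an adapter_id get flagged in the mask and the (stringified) index is parsed into lora_indices; the request objects here are stand-ins for the protobuf requests:

import torch
from types import SimpleNamespace

requests = [
    SimpleNamespace(adapter_id="0"),
    SimpleNamespace(adapter_id=None),
    SimpleNamespace(adapter_id="1"),
]

bs = len(requests)
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool)
lora_indices = torch.zeros(bs, dtype=torch.int32)

for i, r in enumerate(requests):
    if r.adapter_id:
        lora_indices[i] = int(r.adapter_id)
        batch_lora_adapter_mask[i] = True

print(batch_lora_adapter_mask)  # tensor([ True, False,  True])
print(lora_indices)             # tensor([0, 0, 1], dtype=torch.int32)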
@@ -1,3 +1,4 @@
import os
import torch
import torch.distributed

@@ -13,7 +14,9 @@ from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    hub,
)
from text_generation_server.utils.weights import load_adaptor_weights

tracer = trace.get_tracer(__name__)

@@ -29,6 +32,7 @@ class FlashLlama(FlashCausalLM):
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = [],
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -71,7 +75,7 @@ class FlashLlama(FlashCausalLM):
        weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = FlashLlamaForCausalLM(prefix, config, weights)
        model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
@@ -192,6 +192,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

def serve(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
@@ -203,6 +204,7 @@ def serve(
):
    async def serve_inner(
        model_id: str,
        lora_adapter_ids: Optional[List[str]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
@@ -224,6 +226,7 @@ def serve(
        try:
            model = get_model(
                model_id,
                lora_adapter_ids,
                revision,
                sharded,
                quantize,
@@ -262,6 +265,13 @@ def serve(
    set_model_id(model_id)
    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
            model_id,
            lora_adapter_ids,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            trust_remote_code,
        )
    )
@@ -2,6 +2,7 @@ import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
from safetensors.torch import load_file
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
@@ -9,6 +10,27 @@ import json
from text_generation_server.utils.log import log_once


# TODO: improve how the weights are loaded
def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
    adapter_weights = {}
    if local_path.exists() and local_path.is_dir():
        local_files = list(local_path.glob(f"*{extension}"))
        if not local_files:
            raise FileNotFoundError(
                f"No local weights found in {model_id} with extension {extension}"
            )
        for filename in local_files:
            adapter_weights.update(load_file(filename))

    # TODO: remove (no need to sort)
    # sorted on the the layer number (index 4 in the key)
    sorted_keys = sorted(
        adapter_weights.keys(),
        key=lambda x: int(x.split(".")[4]),
    )
    return (adapter_weights, sorted_keys)


class Weights:
    def __init__(
        self,
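A possible call-site sketch for the helper above (not part of the diff): the adapter repo id is hypothetical, and huggingface_hub.snapshot_download is assumed as the way to materialise the adapter's *.safetensors files locally before handing the directory to load_adaptor_weights:

from pathlib import Path
from huggingface_hub import snapshot_download
from text_generation_server.utils.weights import load_adaptor_weights

adapter_id = "some-org/llama-lora-adapter"  # hypothetical adapter repo id
local_path = Path(snapshot_download(adapter_id, allow_patterns=["*.safetensors"]))

adapter_weights, sorted_keys = load_adaptor_weights(adapter_id, local_path)
print(f"loaded {len(adapter_weights)} adapter tensors")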