Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

feat: first draft load multiple lora

parent 85dfc39222
commit db3d8e6518
@@ -157,6 +157,7 @@ async fn prefill(
             top_n_tokens: top_n_tokens.unwrap_or(0),
             blocks: vec![],
             slots: vec![],
+            adapter_id: None,
         })
         .collect();

@@ -107,6 +107,8 @@ message Request {
     bool prefill_logprobs = 6;
     /// Return most likely n tokens
     uint32 top_n_tokens = 7;
+    /// LORA adapter id
+    optional string adapter_id = 8;
 }

 message Batch {

@@ -154,6 +154,7 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
+                adapter_id: None,
             });
             n_tokens += max_input_length;

@@ -290,6 +290,7 @@ impl State {
                     entry.request.stopping_parameters.clone(),
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
+                adapter_id: entry.request.adapter_id.clone(),
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());

@@ -429,6 +430,7 @@ mod tests {
                 stop_sequences: vec![],
             },
             top_n_tokens: 0,
+            adapter_id: None,
         },
         response_tx,
         span: info_span!("entry"),

@@ -298,6 +298,11 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub grammar: Option<GrammarType>,
+
+    /// Lora adapter id
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub adapter_id: Option<String>,
 }

 fn default_max_new_tokens() -> Option<u32> {
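Note: the `adapter_id` field above is the per-request knob for selecting a LoRA adapter. As a hedged illustration only (the endpoint and port are the usual TGI defaults and are not part of this diff; passing "0" follows the draft's later treatment of `adapter_id` as an integer index into the loaded adapters), a client call might look like:

import requests

# Hypothetical request against a locally running server, exercising the new
# optional `adapter_id` parameter added to GenerateParameters in this draft.
response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 32, "adapter_id": "0"},
    },
)
print(response.json())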
@@ -324,6 +329,7 @@ fn default_parameters() -> GenerateParameters {
         seed: None,
         top_n_tokens: None,
         grammar: None,
+        adapter_id: None,
     }
 }

@@ -668,6 +668,7 @@ async fn completions(
                 seed,
                 top_n_tokens: None,
                 grammar: None,
+                ..Default::default()
             },
         })
         .collect();

@@ -1092,6 +1093,7 @@ async fn chat_completions(
             seed,
             top_n_tokens: req.top_logprobs,
             grammar: typed_grammar,
+            ..Default::default()
         },
     };

@@ -202,6 +202,7 @@ impl Validation {
             decoder_input_details,
             top_n_tokens,
             grammar,
+            adapter_id,
             ..
         } = request.parameters;

@@ -383,6 +384,7 @@ impl Validation {
             parameters,
             stopping_parameters,
             top_n_tokens,
+            adapter_id,
         })
     }

@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,
     pub top_n_tokens: u32,
+    pub adapter_id: Option<String>,
 }

 #[derive(Error, Debug)]

@@ -78,6 +78,14 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

+    # TODO: determine if this api makes sense
+    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+
+    # split on comma and strip whitespace
+    lora_adapter_ids = (
+        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
+    )
+
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
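Note: a small self-contained sketch of the parsing above (the adapter names are made up; the draft only expects a comma-separated `LORA_ADAPTERS` value):

import os

# Hypothetical value, as it might be set before launching the server.
os.environ["LORA_ADAPTERS"] = "org/adapter-a, org/adapter-b"

lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
# Same split-and-strip logic as the hunk above.
lora_adapter_ids = (
    [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
assert lora_adapter_ids == ["org/adapter-a", "org/adapter-b"]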
@@ -92,6 +100,7 @@ def serve(
     )
     server.serve(
         model_id,
+        lora_adapter_ids,
         revision,
         sharded,
         quantize,

@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional
+from typing import Optional, List
 from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate

@@ -253,6 +253,7 @@ for data in ModelType:

 def get_model(
     model_id: str,
+    lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],

@@ -595,6 +596,7 @@ def get_model(
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))

@@ -88,9 +88,11 @@ def load_attention(config, prefix, weights):
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
+        all_adapter_weights,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads

@@ -122,6 +124,29 @@ class FlashLlamaAttention(torch.nn.Module):
         )

         self.query_key_value = load_attention(config, prefix, weights)
+        self.index = index
+        self.adapter_weights = {}
+        for adapter_id, adapter_weights in all_adapter_weights.items():
+            filtered_keys = list(
+                filter(
+                    lambda x: x.startswith(
+                        f"base_model.model.model.layers.{index}.self_attn"
+                    ),
+                    adapter_weights.keys(),
+                )
+            )
+            self.adapter_weights[adapter_id] = {
+                key: torch.tensor(
+                    adapter_weights[key],
+                    device=weights.device,
+                    dtype=weights.dtype,
+                ).T
+                for key in filtered_keys
+            }
+
+        self.index_to_key = {
+            i: key for i, key in enumerate(self.adapter_weights.keys())
+        }
+
         self.o_proj = TensorParallelRowLinear.load(
             config,

@@ -134,6 +159,23 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

+    def get_adapter_weights(self, lora_index):
+        adapter_id = self.index_to_key[lora_index]
+        q_proj_lora_a = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
+        ]
+        q_proj_lora_b = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
+        ]
+
+        v_proj_lora_a = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
+        ]
+        v_proj_lora_b = self.adapter_weights[adapter_id][
+            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
+        ]
+        return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b
+
     def forward(
         self,
         hidden_states,

@@ -145,6 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        batch_lora_adapter_mask,
+        lora_indices,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(

@@ -157,6 +201,40 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

+        q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
+            self.get_adapter_weights(
+                # TODO: dont just assume the first adapter
+                lora_indices[0].item()
+            )
+        )
+
+        query_adapted = torch.matmul(
+            hidden_states,
+            torch.matmul(
+                q_proj_lora_a,
+                q_proj_lora_b,
+            ),
+        )
+
+        value_adapted = torch.matmul(
+            hidden_states,
+            torch.matmul(
+                v_proj_lora_a,
+                v_proj_lora_b,
+            ),
+        )
+
+        batch_size = query.size(0)
+
+        # TODO: improve this to avoid unnecessary work
+        # mask across batch and within lora adapters
+        query[batch_lora_adapter_mask] += query_adapted.view(
+            batch_size, self.num_heads, self.head_size
+        )[batch_lora_adapter_mask]
+        kv[batch_lora_adapter_mask, 1] += value_adapted.view(
+            batch_size, self.num_key_value_heads, self.head_size
+        )[batch_lora_adapter_mask]
+
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
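Note: the hunk above is the core of this draft. It applies the standard low-rank adaptation update, h + x·(A·B), to the query and value projections, using only the first adapter referenced by the batch and omitting the usual alpha/r scaling. A minimal standalone sketch of that delta with made-up shapes:

import torch

hidden_size, num_heads, head_size, rank = 4096, 32, 128, 16
tokens = 8

hidden_states = torch.randn(tokens, hidden_size)
# Stand-ins for a q_proj lora_A / lora_B pair, already transposed as in the hunk.
q_proj_lora_a = torch.randn(hidden_size, rank) * 0.01
q_proj_lora_b = torch.randn(rank, num_heads * head_size) * 0.01

# Stand-in for the query slice produced by the base query_key_value projection.
query = torch.randn(tokens, num_heads * head_size)

# Low-rank delta: x @ (A @ B) has the same shape as the base projection output.
query_adapted = torch.matmul(hidden_states, torch.matmul(q_proj_lora_a, q_proj_lora_b))
query = (query + query_adapted).view(-1, num_heads, head_size)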
@@ -261,10 +339,14 @@ class LlamaMLP(nn.Module):


 class FlashLlamaLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights, all_adapter_weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            index=index,
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            all_adapter_weights=all_adapter_weights,
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -289,6 +371,8 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        batch_lora_adapter_mask,
+        lora_indices,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -303,6 +387,8 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            batch_lora_adapter_mask,
+            lora_indices,
         )

         # faster post attention rms norm

@@ -316,7 +402,7 @@ class FlashLlamaLayer(nn.Module):


 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, all_adapter_weights):
         super().__init__()

         process_group = weights.process_group

@@ -325,6 +411,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
+                    index=layer_id,
                     prefix=(
                         f"model.layers.{layer_id}"
                         if not prefix

@@ -332,6 +419,7 @@ class FlashLlamaModel(torch.nn.Module):
                     ),
                     config=config,
                     weights=weights,
+                    all_adapter_weights=all_adapter_weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]

@@ -360,6 +448,8 @@ class FlashLlamaModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        batch_lora_adapter_mask: Optional[List[str]],
+        lora_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = inputs_embeds

@@ -382,6 +472,8 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                batch_lora_adapter_mask,
+                lora_indices,
             )

         hidden_states, _ = self.norm(hidden_states, residual)

@@ -390,7 +482,7 @@ class FlashLlamaModel(torch.nn.Module):


 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, all_adapter_weights):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(

@@ -399,7 +491,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.model = FlashLlamaModel(prefix, config, weights)
+        self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:

@@ -423,6 +515,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        batch_lora_adapter_mask: Optional[List[str]] = None,
+        lora_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(

@@ -436,6 +530,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
+            batch_lora_adapter_mask=batch_lora_adapter_mask,
+            lora_indices=lora_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]

@@ -811,6 +811,8 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

+        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
+        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
         torch.cuda.synchronize()
         # Run once outside to warmup
         self.model.forward(

@@ -824,6 +826,8 @@ class FlashCausalLM(Model):
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
+           batch_lora_adapter_mask=batch_lora_adapter_mask,
+           lora_indices=lora_indices,
         )
         torch.cuda.synchronize()

@@ -839,6 +843,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
+                batch_lora_adapter_mask=batch_lora_adapter_mask,
+                lora_indices=lora_indices,
             )
         self.cuda_graphs[bs]["logits"] = logits
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits

@@ -966,6 +972,10 @@ class FlashCausalLM(Model):

         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        batch_lora_adapter_mask = torch.zeros(
+            seqlen, dtype=torch.bool, device=self.device
+        )
+        lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(

@@ -981,6 +991,8 @@ class FlashCausalLM(Model):
             max_s=seqlen,
             lm_head_indices=None,
             prefill_cache_indices=None,
+            batch_lora_adapter_mask=batch_lora_adapter_mask,
+            lora_indices=lora_indices,
         )

     def forward(

@@ -1051,6 +1063,15 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None

+        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
+        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
+
+        for i, r in enumerate(batch.requests):
+            if r.adapter_id:
+                lora_index = int(r.adapter_id)
+                lora_indices[i] = lora_index
+                batch_lora_adapter_mask[i] = True
+
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
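Note: the loop above turns per-request adapter ids into batch-level tensors; this draft expects `adapter_id` to be the integer index of a loaded adapter, encoded as a string. A small sketch of the resulting mask and indices for a hypothetical batch:

import torch
from types import SimpleNamespace

# Hypothetical batch: requests 0 and 2 use adapters "0" and "1", request 1 uses none.
requests = [
    SimpleNamespace(adapter_id="0"),
    SimpleNamespace(adapter_id=None),
    SimpleNamespace(adapter_id="1"),
]

bs = len(requests)
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool)
lora_indices = torch.zeros(bs, dtype=torch.int32)

for i, r in enumerate(requests):
    if r.adapter_id:
        lora_indices[i] = int(r.adapter_id)
        batch_lora_adapter_mask[i] = True

# batch_lora_adapter_mask -> [True, False, True], lora_indices -> [0, 0, 1]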
@@ -1063,6 +1084,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
+                batch_lora_adapter_mask=batch_lora_adapter_mask,
+                lora_indices=lora_indices,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None

@@ -1,3 +1,4 @@
+import os
 import torch
 import torch.distributed

@@ -13,7 +14,9 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    hub,
 )
+from text_generation_server.utils.weights import load_adaptor_weights

 tracer = trace.get_tracer(__name__)

@@ -29,6 +32,7 @@ class FlashLlama(FlashCausalLM):
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -71,7 +75,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         prefix = ""
-        model = FlashLlamaForCausalLM(prefix, config, weights)
+        model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,

@@ -192,6 +192,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

 def serve(
     model_id: str,
+    lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],

@@ -203,6 +204,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
+        lora_adapter_ids: Optional[List[str]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,

@@ -224,6 +226,7 @@ def serve(
         try:
             model = get_model(
                 model_id,
+                lora_adapter_ids,
                 revision,
                 sharded,
                 quantize,

@@ -262,6 +265,13 @@ def serve(
     set_model_id(model_id)
     asyncio.run(
         serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+            model_id,
+            lora_adapter_ids,
+            revision,
+            sharded,
+            quantize,
+            speculate,
+            dtype,
+            trust_remote_code,
         )
     )

@@ -2,6 +2,7 @@ import os
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
+from safetensors.torch import load_file
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download

@@ -9,6 +10,27 @@ import json
 from text_generation_server.utils.log import log_once


+# TODO: improve how the weights are loaded
+def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
+    adapter_weights = {}
+    if local_path.exists() and local_path.is_dir():
+        local_files = list(local_path.glob(f"*{extension}"))
+        if not local_files:
+            raise FileNotFoundError(
+                f"No local weights found in {model_id} with extension {extension}"
+            )
+        for filename in local_files:
+            adapter_weights.update(load_file(filename))
+
+    # TODO: remove (no need to sort)
+    # sorted on the the layer number (index 4 in the key)
+    sorted_keys = sorted(
+        adapter_weights.keys(),
+        key=lambda x: int(x.split(".")[4]),
+    )
+    return (adapter_weights, sorted_keys)
+
+
 class Weights:
     def __init__(
         self,
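Note: the diff does not show how `all_adapter_weights` is assembled in `flash_llama.py`, only the helper above and the new imports. A hedged sketch, under the assumption that each adapter id maps to a locally downloaded PEFT-style directory of *.safetensors files:

from pathlib import Path

from text_generation_server.utils.weights import load_adaptor_weights

# Hypothetical local adapter snapshots keyed by adapter id.
adapter_dirs = {
    "org/adapter-a": Path("/data/adapters/adapter-a"),
    "org/adapter-b": Path("/data/adapters/adapter-b"),
}

all_adapter_weights = {}
for adapter_id, local_path in adapter_dirs.items():
    weights, _sorted_keys = load_adaptor_weights(adapter_id, local_path)
    # Keys look like: base_model.model.model.layers.{i}.self_attn.q_proj.lora_A.weight
    all_adapter_weights[adapter_id] = weights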