Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-29 22:12:07 +00:00)

feat: baseline impl single request multi lora support

Parent: a046c303f7
Commit: c661631225

@@ -121,6 +121,7 @@ def download_weights(
     logger_level: str = "INFO",
     json_output: bool = False,
     trust_remote_code: bool = False,
+    merge_lora: bool = False,
 ):
     # Remove default handler
     logger.remove()
@@ -151,18 +152,25 @@ def download_weights(
     ) is not None

     if not is_local_model:
-        try:
-            adapter_config_filename = hf_hub_download(
-                model_id, revision=revision, filename="adapter_config.json"
-            )
-            utils.download_and_unload_peft(
+        # TODO: maybe reverse the default value of merge_lora?
+        # currently by default we don't merge the weights with the base model
+        if merge_lora:
+            try:
+                adapter_config_filename = hf_hub_download(
+                    model_id, revision=revision, filename="adapter_config.json"
+                )
+                utils.download_and_unload_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+                is_local_model = True
+                utils.weight_files(model_id, revision, extension)
+                return
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        else:
+            utils.peft.download_peft(
                 model_id, revision, trust_remote_code=trust_remote_code
             )
-            is_local_model = True
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass

     try:
         import json
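
Note (not part of the commit): merging a LoRA adapter folds the low-rank update into the base weights once, while the new default keeps the adapter separate so it can be applied per request at runtime. A minimal torch sketch of the equivalence, with made-up shapes:

import torch

hidden, rank = 64, 8
scale = 16.0 / rank                # lora_alpha / r

W = torch.randn(hidden, hidden)    # base weight
A = torch.randn(rank, hidden)      # lora_A.weight
B = torch.randn(hidden, rank)      # lora_B.weight
x = torch.randn(3, hidden)         # a few token hidden states

# unmerged: base output plus a separately applied low-rank update
y_unmerged = x @ W.T + scale * (x @ A.T) @ B.T

# merged: the update is baked into the weight ahead of time
W_merged = W + scale * B @ A
y_merged = x @ W_merged.T

assert torch.allclose(y_unmerged, y_merged, atol=1e-4)
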
@@ -92,7 +92,8 @@ class FlashLlamaAttention(torch.nn.Module):
         prefix: str,
         config,
         weights,
-        all_adapter_weights,
+        lora_weights,
+        lora_configs,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -126,36 +127,24 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
         self.index = index
         self.adapter_weights = {}
-        adapter_names = list(all_adapter_weights.keys())
+        adapter_names = list(lora_weights.keys())

-        self.lora_a_matrix = torch.empty(
-            (len(adapter_names), 2, 4096, 8),
-            device=weights.device,
-            dtype=weights.dtype,
-        )
-        self.lora_b_matrix = torch.empty(
-            (len(adapter_names), 2, 8, 4096),
-            device=weights.device,
-            dtype=weights.dtype,
-        )
-
-        self.pre_multiplied_lora_matrix = torch.empty(
-            (len(adapter_names), 2, 4096, 4096),
+        self.n_loras = len(adapter_names)
+        self.pre_multiplied_lora_matrices = torch.empty(
+            (self.n_loras, 2, self.hidden_size, self.hidden_size),
             device=weights.device,
             dtype=weights.dtype,
         )

         self.key_to_index = {}
-        self.index_to_key = {}

         lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
         for adapter_index, adapter_name in enumerate(adapter_names):
-            self.lora_alpha = 16.0
-            self.lora_r = 8.0
+            self.lora_alpha = lora_configs[adapter_name].lora_alpha
+            self.lora_r = lora_configs[adapter_name].r
             self.lora_scale = self.lora_alpha / self.lora_r
             self.key_to_index[adapter_name] = adapter_index
-            self.index_to_key[adapter_index] = adapter_name
-            adapter_weights = all_adapter_weights[adapter_name]
+            adapter_weights = lora_weights[adapter_name]
             for target_index, target in enumerate(["q", "v"]):
                 adapter_weight_a = adapter_weights.get_tensor(
                     f"{lora_prefix}.{target}_proj.lora_A.weight"
@@ -168,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
                     adapter_weight_b.T,
                 ).contiguous()

-                self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = (
+                self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
                     pre_multiplied_lora_matrix
                 )

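
Note (illustration, not commit code): the constructor above stores, per adapter and per target ("q" and "v"), a single pre-multiplied (hidden_size, hidden_size) matrix instead of the separate lora_A/lora_B factors, so the forward pass needs only one matmul per adapter. A sketch of why that works, with smaller made-up shapes; whether the alpha/r scale is folded into the stored matrix or applied later does not change the algebra:

import torch

hidden, rank = 256, 8
lora_scale = 16.0 / 8.0              # lora_alpha / r

lora_A = torch.randn(rank, hidden)   # shape of <prefix>.lora_A.weight
lora_B = torch.randn(hidden, rank)   # shape of <prefix>.lora_B.weight

# one (hidden, hidden) matrix equivalent to applying A then B
pre_multiplied = lora_scale * torch.matmul(lora_A.T, lora_B.T).contiguous()

x = torch.randn(5, hidden)           # token hidden states
delta_direct = lora_scale * (x @ lora_A.T) @ lora_B.T
delta_premul = x @ pre_multiplied

assert torch.allclose(delta_direct, delta_premul, atol=1e-3)

The memory trade-off is explicit in the shapes: two rank-8 factors per adapter become one dense hidden_size x hidden_size matrix per adapter and target.
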
@@ -209,16 +198,26 @@ class FlashLlamaAttention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

         batch_size = query.size(0)
-        query_adapted = (
-            torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0])
-            .squeeze(0)
-            .view(batch_size, self.num_heads, self.head_size)
-        )

-        value_adapted = (
-            torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1])
-            .squeeze(0)
-            .view(batch_size, self.num_key_value_heads, self.head_size)
+        # hidden states without LoRA
+        hs_wl = hidden_states[lora_indices == -1]
+
+        adapted_query_states = [hs_wl]
+        adapted_value_states = [hs_wl]
+
+        for ind in range(self.n_loras):
+            mask = lora_indices == ind
+            hs_sub = hidden_states[mask]
+            mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0])
+            mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1])
+            adapted_query_states.append(mat_q)
+            adapted_value_states.append(mat_v)
+
+        query_adapted = torch.cat(adapted_query_states, dim=0).view(
+            batch_size, self.num_heads, self.head_size
+        )
+        value_adapted = torch.cat(adapted_value_states, dim=0).view(
+            batch_size, self.num_key_value_heads, self.head_size
         )

         query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
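
Note (illustration, not commit code): after this change each token row selects the pre-multiplied matrix of its own adapter through lora_indices, with -1 meaning "no adapter", instead of one batched bmm over every adapter. A standalone sketch of that per-token selection that computes the same deltas while keeping the original token order; names and shapes are made up:

import torch

n_tokens, hidden, n_loras = 6, 32, 2
hidden_states = torch.randn(n_tokens, hidden)
# one pre-multiplied (hidden, hidden) matrix per adapter, here only the "q" target
pre_q = torch.randn(n_loras, hidden, hidden)
# per-token adapter index, -1 = request served by the base model only
lora_indices = torch.tensor([0, 0, -1, 1, 1, -1])

delta_q = torch.zeros_like(hidden_states)
for ind in range(n_loras):
    mask = lora_indices == ind
    delta_q[mask] = hidden_states[mask] @ pre_q[ind]
# delta_q is then reshaped to (tokens, num_heads, head_size) and added to the
# query/value projections, as the diff does via batch_lora_adapter_mask.
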
@@ -328,14 +327,15 @@ class LlamaMLP(nn.Module):


 class FlashLlamaLayer(nn.Module):
-    def __init__(self, index, prefix, config, weights, all_adapter_weights):
+    def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
             index=index,
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
-            all_adapter_weights=all_adapter_weights,
+            lora_weights=lora_weights,
+            lora_configs=lora_configs,
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -391,7 +391,7 @@ class FlashLlamaLayer(nn.Module):


 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, all_adapter_weights):
+    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
         super().__init__()

         process_group = weights.process_group
@@ -408,7 +408,8 @@ class FlashLlamaModel(torch.nn.Module):
                    ),
                    config=config,
                    weights=weights,
-                    all_adapter_weights=all_adapter_weights,
+                    lora_weights=lora_weights,
+                    lora_configs=lora_configs,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
@@ -471,7 +472,7 @@ class FlashLlamaModel(torch.nn.Module):


 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, all_adapter_weights):
+    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(
@@ -480,7 +481,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
+        self.model = FlashLlamaModel(
+            prefix, config, weights, lora_weights, lora_configs
+        )
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -1069,7 +1069,8 @@ class FlashCausalLM(Model):
         for i, r in enumerate(batch.requests):
             if r.adapter_id:
                 lora_index = self.model.get_lora_index(r.adapter_id)
-                lora_indices[i] = lora_index
+                input_length = batch.input_lengths[i]
+                lora_indices[i : i + input_length] = lora_index
                 batch_lora_adapter_mask[i] = True

         if cu_seqlen_prefill is not None or cuda_graph is None:
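
Note (illustration, not commit code): the change above moves lora_indices from one entry per request to one entry per token, since during prefill each request occupies input_length consecutive positions of the flattened batch. A simplified, standalone sketch of that expansion using cumulative offsets (my own bookkeeping, not the commit's); the adapter ids are made up:

import torch

input_lengths = [4, 2, 3]             # tokens per request in one prefill batch
request_adapter = ["fr", None, "de"]  # adapter id per request, None = base model
adapter_to_index = {"fr": 0, "de": 1}

lora_indices = torch.full((sum(input_lengths),), -1, dtype=torch.long)

offset = 0
for length, adapter_id in zip(input_lengths, request_adapter):
    if adapter_id is not None:
        lora_indices[offset : offset + length] = adapter_to_index[adapter_id]
    offset += length

print(lora_indices)  # tensor([ 0,  0,  0,  0, -1, -1,  1,  1,  1])
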
@@ -16,11 +16,11 @@ from text_generation_server.utils import (
     Weights,
     hub,
 )
-from text_generation_server.utils.weights import load_adaptor_weights

 tracer = trace.get_tracer(__name__)

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.lora import LoraConfig


 class FlashLlama(FlashCausalLM):
@@ -75,7 +75,9 @@ class FlashLlama(FlashCausalLM):
            weights._set_gptq_params(model_id, revision)

        prefix = ""
-        model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
+        model = FlashLlamaForCausalLM(
+            prefix, config, weights, lora_weights, lora_configs
+        )
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
@@ -86,6 +86,18 @@ def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
     return filenames


+def _adapter_config_files_from_dir(d: Path) -> List[str]:
+    # os.walk: do not iterate, just scan for depth 1, not recursively
+    # see _weight_files_from_dir, that's also what is done there
+    root, _, files = next(os.walk(str(d)))
+    filenames = [
+        os.path.join(root, f)
+        for f in files
+        if f.endswith(".json") and "arguments" not in f and "args" not in f
+    ]
+    return filenames
+
+
 def _get_cached_revision_directory(
     model_id: str, revision: Optional[str]
 ) -> Optional[Path]:
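
Note (illustration, not commit code): the new helper mirrors _adapter_weight_files_from_dir but collects JSON files, filtering out training-argument dumps. A throwaway check of the same filter logic:

import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    for name in ["adapter_config.json", "training_args.json", "adapter_model.safetensors"]:
        Path(d, name).touch()

    root, _, files = next(os.walk(d))  # depth-1 scan, as in the helper
    configs = [
        os.path.join(root, f)
        for f in files
        if f.endswith(".json") and "arguments" not in f and "args" not in f
    ]
    # only adapter_config.json survives; training_args.json is filtered out
    assert [os.path.basename(f) for f in configs] == ["adapter_config.json"]
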
server/text_generation_server/utils/lora.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import json
+from text_generation_server.utils import (
+    hub,
+)
+import os
+
+
+class LoraConfig:
+    def __init__(
+        self,
+        alpha_pattern=None,
+        auto_mapping=None,
+        base_model_name_or_path="",
+        bias="none",
+        fan_in_fan_out=False,
+        inference_mode=True,
+        init_lora_weights=True,
+        layer_replication=None,
+        layers_pattern=None,
+        layers_to_transform=None,
+        loftq_config=None,
+        lora_alpha=16,
+        lora_dropout=0.1,
+        megatron_config=None,
+        megatron_core="megatron.core",
+        modules_to_save=None,
+        peft_type="LORA",
+        r=8,
+        rank_pattern=None,
+        revision=None,
+        target_modules=None,
+        task_type="CAUSAL_LM",
+        use_dora=False,
+        use_rslora=False,
+    ):
+        self.alpha_pattern = alpha_pattern or {}
+        self.auto_mapping = auto_mapping
+        self.base_model_name_or_path = base_model_name_or_path
+        self.bias = bias
+        self.fan_in_fan_out = fan_in_fan_out
+        self.inference_mode = inference_mode
+        self.init_lora_weights = init_lora_weights
+        self.layer_replication = layer_replication
+        self.layers_pattern = layers_pattern
+        self.layers_to_transform = layers_to_transform
+        self.loftq_config = loftq_config or {}
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = lora_dropout
+        self.megatron_config = megatron_config
+        self.megatron_core = megatron_core
+        self.modules_to_save = modules_to_save
+        self.peft_type = peft_type
+        self.r = r
+        self.rank_pattern = rank_pattern or {}
+        self.revision = revision
+        self.target_modules = target_modules or ["q_proj", "v_proj"]
+        self.task_type = task_type
+        self.use_dora = use_dora
+        self.use_rslora = use_rslora
+
+    @classmethod
+    def from_file(cls, filename):
+        with open(filename, "r") as f:
+            json_data = json.load(f)
+        return cls(**json_data)
+
+    # TODO: support fetching the model from the hub if it's not in the cache
+    @classmethod
+    def from_pretrained(cls, adapter_id, revision=None):
+        d = hub._get_cached_revision_directory(adapter_id, revision)
+        filename = os.path.join(d, "adapter_config.json")
+        return cls.from_file(filename)
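
Note (usage sketch, not part of the commit): LoraConfig is a lightweight stand-in for peft's config class; it mirrors the fields of adapter_config.json and defaults to LoRA on q_proj/v_proj with alpha 16 and rank 8. The path and adapter id below are placeholders:

from text_generation_server.utils.lora import LoraConfig

# parse a config that is already on disk
config = LoraConfig.from_file("/path/to/adapter_config.json")
print(config.lora_alpha, config.r, config.target_modules)

# or resolve it from the locally cached snapshot of an adapter repo
cached = LoraConfig.from_pretrained("my-org/my-llama-lora")
scale = cached.lora_alpha / cached.r   # the scale FlashLlamaAttention uses

from_pretrained only looks in the local Hugging Face cache; the TODO above covers fetching from the hub.
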
@@ -43,3 +43,24 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     model.save_pretrained(cache_dir, safe_serialization=True)
     model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
+
+
+def download_peft(model_id, revision, trust_remote_code):
+    torch_dtype = torch.float16
+    try:
+        _model = AutoPeftModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
+        )
+    except Exception:
+        _model = AutoPeftModelForSeq2SeqLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
+        )
+    logger.info("Peft model downloaded.")
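
Note (usage sketch, not part of the commit): download_peft loads the adapter once with AutoPeftModelForCausalLM (falling back to the seq2seq variant) purely so the files land in the local Hugging Face cache; the model object itself is discarded. Assuming the module path implied by the CLI call (utils.peft), with a placeholder adapter id:

from text_generation_server.utils.peft import download_peft

# warm the local cache with the adapter files; nothing stays loaded afterwards
download_peft(
    "my-org/my-llama-lora",
    revision=None,
    trust_remote_code=False,
)
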
@@ -10,27 +10,6 @@ import json
 from text_generation_server.utils.log import log_once


-# TODO: improve how the weights are loaded
-def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
-    adapter_weights = {}
-    if local_path.exists() and local_path.is_dir():
-        local_files = list(local_path.glob(f"*{extension}"))
-        if not local_files:
-            raise FileNotFoundError(
-                f"No local weights found in {model_id} with extension {extension}"
-            )
-        for filename in local_files:
-            adapter_weights.update(load_file(filename))
-
-    # TODO: remove (no need to sort)
-    # sorted on the the layer number (index 4 in the key)
-    sorted_keys = sorted(
-        adapter_weights.keys(),
-        key=lambda x: int(x.split(".")[4]),
-    )
-    return (adapter_weights, sorted_keys)
-
-
 class Weights:
     def __init__(
         self,