From c661631225ac227360b4094329189d3dfe438ba4 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 4 Jun 2024 20:07:28 +0000 Subject: [PATCH] feat: baseline impl single request multi lora support --- server/text_generation_server/cli.py | 28 ++++--- .../custom_modeling/flash_llama_modeling.py | 75 ++++++++++--------- .../models/flash_causal_lm.py | 3 +- .../models/flash_llama.py | 6 +- server/text_generation_server/utils/hub.py | 12 +++ server/text_generation_server/utils/lora.py | 72 ++++++++++++++++++ server/text_generation_server/utils/peft.py | 21 ++++++ .../text_generation_server/utils/weights.py | 21 ------ 8 files changed, 168 insertions(+), 70 deletions(-) create mode 100644 server/text_generation_server/utils/lora.py diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index b18deabc..87721097 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -121,6 +121,7 @@ def download_weights( logger_level: str = "INFO", json_output: bool = False, trust_remote_code: bool = False, + merge_lora: bool = False, ): # Remove default handler logger.remove() @@ -151,18 +152,25 @@ def download_weights( ) is not None if not is_local_model: - try: - adapter_config_filename = hf_hub_download( - model_id, revision=revision, filename="adapter_config.json" - ) - utils.download_and_unload_peft( + # TODO: maybe reverse the default value of merge_lora? + # currently by default we don't merge the weights with the base model + if merge_lora: + try: + adapter_config_filename = hf_hub_download( + model_id, revision=revision, filename="adapter_config.json" + ) + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + is_local_model = True + utils.weight_files(model_id, revision, extension) + return + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + else: + utils.peft.download_peft( model_id, revision, trust_remote_code=trust_remote_code ) - is_local_model = True - utils.weight_files(model_id, revision, extension) - return - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass try: import json diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index b41712f4..fac54480 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -92,7 +92,8 @@ class FlashLlamaAttention(torch.nn.Module): prefix: str, config, weights, - all_adapter_weights, + lora_weights, + lora_configs, ): super().__init__() self.num_heads = config.num_attention_heads @@ -126,36 +127,24 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) self.index = index self.adapter_weights = {} - adapter_names = list(all_adapter_weights.keys()) + adapter_names = list(lora_weights.keys()) - self.lora_a_matrix = torch.empty( - (len(adapter_names), 2, 4096, 8), - device=weights.device, - dtype=weights.dtype, - ) - self.lora_b_matrix = torch.empty( - (len(adapter_names), 2, 8, 4096), - device=weights.device, - dtype=weights.dtype, - ) - - self.pre_multiplied_lora_matrix = torch.empty( - (len(adapter_names), 2, 4096, 4096), + self.n_loras = len(adapter_names) + self.pre_multiplied_lora_matrices = torch.empty( + (self.n_loras, 2, self.hidden_size, self.hidden_size), device=weights.device, dtype=weights.dtype, ) 
self.key_to_index = {} - self.index_to_key = {} lora_prefix = f"base_model.model.model.layers.{index}.self_attn" for adapter_index, adapter_name in enumerate(adapter_names): - self.lora_alpha = 16.0 - self.lora_r = 8.0 + self.lora_alpha = lora_configs[adapter_name].lora_alpha + self.lora_r = lora_configs[adapter_name].r self.lora_scale = self.lora_alpha / self.lora_r self.key_to_index[adapter_name] = adapter_index - self.index_to_key[adapter_index] = adapter_name - adapter_weights = all_adapter_weights[adapter_name] + adapter_weights = lora_weights[adapter_name] for target_index, target in enumerate(["q", "v"]): adapter_weight_a = adapter_weights.get_tensor( f"{lora_prefix}.{target}_proj.lora_A.weight" @@ -168,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module): adapter_weight_b.T, ).contiguous() - self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = ( + self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = ( pre_multiplied_lora_matrix ) @@ -209,16 +198,26 @@ class FlashLlamaAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) batch_size = query.size(0) - query_adapted = ( - torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0]) - .squeeze(0) - .view(batch_size, self.num_heads, self.head_size) - ) - value_adapted = ( - torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1]) - .squeeze(0) - .view(batch_size, self.num_key_value_heads, self.head_size) + # hidden states without LoRA + hs_wl = hidden_states[lora_indices == -1] + + adapted_query_states = [hs_wl] + adapted_value_states = [hs_wl] + + for ind in range(self.n_loras): + mask = lora_indices == ind + hs_sub = hidden_states[mask] + mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0]) + mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1]) + adapted_query_states.append(mat_q) + adapted_value_states.append(mat_v) + + query_adapted = torch.cat(adapted_query_states, dim=0).view( + batch_size, self.num_heads, self.head_size + ) + value_adapted = torch.cat(adapted_value_states, dim=0).view( + batch_size, self.num_key_value_heads, self.head_size ) query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask] @@ -328,14 +327,15 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): - def __init__(self, index, prefix, config, weights, all_adapter_weights): + def __init__(self, index, prefix, config, weights, lora_weights, lora_configs): super().__init__() self.self_attn = FlashLlamaAttention( index=index, prefix=f"{prefix}.self_attn", config=config, weights=weights, - all_adapter_weights=all_adapter_weights, + lora_weights=lora_weights, + lora_configs=lora_configs, ) self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -391,7 +391,7 @@ class FlashLlamaLayer(nn.Module): class FlashLlamaModel(torch.nn.Module): - def __init__(self, prefix, config, weights, all_adapter_weights): + def __init__(self, prefix, config, weights, lora_weights, lora_configs): super().__init__() process_group = weights.process_group @@ -408,7 +408,8 @@ class FlashLlamaModel(torch.nn.Module): ), config=config, weights=weights, - all_adapter_weights=all_adapter_weights, + lora_weights=lora_weights, + lora_configs=lora_configs, ) for layer_id in range(config.num_hidden_layers) ] @@ -471,7 +472,7 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, all_adapter_weights): + def __init__(self, prefix, 
config, weights, lora_weights, lora_configs): super().__init__() self.embed_tokens = TensorParallelEmbedding( @@ -480,7 +481,9 @@ class FlashLlamaForCausalLM(torch.nn.Module): ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights) + self.model = FlashLlamaModel( + prefix, config, weights, lora_weights, lora_configs + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 0062cd55..0131c0e0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1069,7 +1069,8 @@ class FlashCausalLM(Model): for i, r in enumerate(batch.requests): if r.adapter_id: lora_index = self.model.get_lora_index(r.adapter_id) - lora_indices[i] = lora_index + input_length = batch.input_lengths[i] + lora_indices[i : i + input_length] = lora_index batch_lora_adapter_mask[i] = True if cu_seqlen_prefill is not None or cuda_graph is None: diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 78b35276..c5d3ecac 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -16,11 +16,11 @@ from text_generation_server.utils import ( Weights, hub, ) -from text_generation_server.utils.weights import load_adaptor_weights tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.lora import LoraConfig class FlashLlama(FlashCausalLM): @@ -75,7 +75,9 @@ class FlashLlama(FlashCausalLM): weights._set_gptq_params(model_id, revision) prefix = "" - model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights) + model = FlashLlamaForCausalLM( + prefix, config, weights, lora_weights, lora_configs + ) torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( model=model, diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index d41700e8..db412aeb 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -86,6 +86,18 @@ def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]: return filenames +def _adapter_config_files_from_dir(d: Path) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_files_from_dir, that's also what is done there + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(".json") and "arguments" not in f and "args" not in f + ] + return filenames + + def _get_cached_revision_directory( model_id: str, revision: Optional[str] ) -> Optional[Path]: diff --git a/server/text_generation_server/utils/lora.py b/server/text_generation_server/utils/lora.py new file mode 100644 index 00000000..64a6724a --- /dev/null +++ b/server/text_generation_server/utils/lora.py @@ -0,0 +1,72 @@ +import json +from text_generation_server.utils import ( + hub, +) +import os + + +class LoraConfig: + def __init__( + self, + alpha_pattern=None, + auto_mapping=None, + base_model_name_or_path="", + bias="none", + fan_in_fan_out=False, + inference_mode=True, + init_lora_weights=True, + layer_replication=None, + layers_pattern=None, + layers_to_transform=None, + loftq_config=None, + lora_alpha=16, + lora_dropout=0.1, + 
megatron_config=None, + megatron_core="megatron.core", + modules_to_save=None, + peft_type="LORA", + r=8, + rank_pattern=None, + revision=None, + target_modules=None, + task_type="CAUSAL_LM", + use_dora=False, + use_rslora=False, + ): + self.alpha_pattern = alpha_pattern or {} + self.auto_mapping = auto_mapping + self.base_model_name_or_path = base_model_name_or_path + self.bias = bias + self.fan_in_fan_out = fan_in_fan_out + self.inference_mode = inference_mode + self.init_lora_weights = init_lora_weights + self.layer_replication = layer_replication + self.layers_pattern = layers_pattern + self.layers_to_transform = layers_to_transform + self.loftq_config = loftq_config or {} + self.lora_alpha = lora_alpha + self.lora_dropout = lora_dropout + self.megatron_config = megatron_config + self.megatron_core = megatron_core + self.modules_to_save = modules_to_save + self.peft_type = peft_type + self.r = r + self.rank_pattern = rank_pattern or {} + self.revision = revision + self.target_modules = target_modules or ["q_proj", "v_proj"] + self.task_type = task_type + self.use_dora = use_dora + self.use_rslora = use_rslora + + @classmethod + def from_file(cls, filename): + with open(filename, "r") as f: + json_data = json.load(f) + return cls(**json_data) + + # TODO: support fetching the model from the hub if it's not in the cache + @classmethod + def from_pretrained(cls, adapter_id, revision=None): + d = hub._get_cached_revision_directory(adapter_id, revision) + filename = os.path.join(d, "adapter_config.json") + return cls.from_file(filename) diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index 48ca264b..5aaeb5ac 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -43,3 +43,24 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): model.save_pretrained(cache_dir, safe_serialization=True) model.config.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir) + + +def download_peft(model_id, revision, trust_remote_code): + torch_dtype = torch.float16 + try: + _model = AutoPeftModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + except Exception: + _model = AutoPeftModelForSeq2SeqLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + logger.info("Peft model downloaded.") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index ab18e0c7..efede312 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -10,27 +10,6 @@ import json from text_generation_server.utils.log import log_once -# TODO: improve how the weights are loaded -def load_adaptor_weights(model_id, local_path, extension=".safetensors"): - adapter_weights = {} - if local_path.exists() and local_path.is_dir(): - local_files = list(local_path.glob(f"*{extension}")) - if not local_files: - raise FileNotFoundError( - f"No local weights found in {model_id} with extension {extension}" - ) - for filename in local_files: - adapter_weights.update(load_file(filename)) - - # TODO: remove (no need to sort) - # sorted on the the layer number (index 4 in the key) - sorted_keys = sorted( - adapter_weights.keys(), - key=lambda x: int(x.split(".")[4]), - ) - return (adapter_weights, sorted_keys) - - 
class Weights: def __init__( self,
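
The core idea in the flash_llama_modeling.py changes above: each adapter's LoRA factors A and B are folded into a single pre-multiplied matrix per projection (with the lora_alpha / r scaling assumed to be folded in as well), and at runtime token positions are routed to adapters through a per-token lora_indices tensor, where -1 marks positions with no adapter. Below is a minimal, hedged sketch of that equivalence and routing; hidden_size, n_loras, and the variable names are illustrative choices, not values taken from the patch, and the routing loop uses a masked scatter for simplicity where the patch groups rows by concatenation.

import torch

torch.manual_seed(0)

hidden_size, r, lora_alpha = 64, 8, 16   # illustrative sizes, not the model's
n_loras, n_tokens = 2, 5
scale = lora_alpha / r

# PEFT stores lora_A.weight as (r, in_features) and lora_B.weight as (out_features, r).
lora_a = [torch.randn(r, hidden_size) for _ in range(n_loras)]
lora_b = [torch.randn(hidden_size, r) for _ in range(n_loras)]

# Pre-multiplied form: one dense (hidden_size, hidden_size) matrix per adapter
# (the patch keeps two of these per adapter and layer, one for q_proj and one
# for v_proj); the alpha/r scaling is assumed to be baked in here.
pre_multiplied = torch.stack([scale * a.T @ b.T for a, b in zip(lora_a, lora_b)])

hidden_states = torch.randn(n_tokens, hidden_size)

# Classic LoRA delta for adapter 0: scale * (x @ A^T) @ B^T ...
delta_classic = scale * hidden_states @ lora_a[0].T @ lora_b[0].T
# ... equals a single matmul against the pre-multiplied matrix.
assert torch.allclose(
    delta_classic, hidden_states @ pre_multiplied[0], rtol=1e-4, atol=1e-4
)

# Per-token routing in the spirit of the loop added to FlashLlamaAttention.forward:
# -1 means "no adapter", any other value is an adapter index.
lora_indices = torch.tensor([-1, 0, 0, 1, -1])
delta = torch.zeros_like(hidden_states)
for ind in range(n_loras):
    mask = lora_indices == ind
    delta[mask] = hidden_states[mask] @ pre_multiplied[ind]

# Rows with index -1 keep a zero delta; adapted rows carry their adapter's
# update, which the model then adds onto the base q/v projection outputs.
print(delta.abs().sum(dim=-1))

The trade-off is memory for speed: instead of keeping the low-rank factors around, the patch allocates a dense (n_loras, 2, hidden_size, hidden_size) buffer per attention layer, so every additional adapter costs two full projection-sized matrices per layer but only one extra matmul per routed token at inference time.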