mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
* feat: first draft load multiple lora * feat: load weights within layer and refactor lora pass * fix: refactor and reduce lora math * feat: baseline impl single request multi lora support * feat: prefer lorax implementation and port loading logic * fix: prefer adapter_data and refactors * feat: prefer lorax's custom punica kernels and add mlp loras * fix: adjust batch for bgmv * fix: adjust adapter_segments logic when in batch * fix: refactor and move changes to v3 proto * fix: pass model_id for all flash causal lms * fix: pass model_id for all causal and seq2seq lms * fix: add model_id to model test * feat: add lora support to mistral and refactors * feat: prefer model id in request * fix: include rust code for adapter id * feat: bump launcher and add new lora docs * feat: support base model generation and refactors * fix: rename doc to retry ci build * feat: support if vlm models * fix: add adapter_data param and avoid missing layers * fix: add adapter_data param to phi and neox * fix: update all models forwards to include adapter_data * fix: add model_id to IdeficsCausalLM * Update lora.md Fixed a typo * Update lora.md Fixing spam image * fix: add lora kernel to dockerfile, support running without kernels and refactors * fix: avoid dockerfile conflict * fix: refactors and adjust flash llama lora logic * fix: skip llama test due to CI issue (temp) * fix: skip llama test CI (temp) 2 * fix: revert skips and prefer updated ci token for tests * fix: refactors and helpful comments * fix: add noop in TensorParallelAdapterRowLinear too * fix: refactor and move shard_lora_weights logic * fix: exit early if no adapter_data --------- Co-authored-by: Derek <datavistics@gmail.com>
45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
# Origin: https://github.com/predibase/lorax
|
|
# Path: lorax/server/lorax_server/adapters/config.py
|
|
# License: Apache License Version 2.0, January 2004
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
|
|
|
import torch
|
|
|
|
from text_generation_server.adapters.weights import AdapterWeights
|
|
|
|
if TYPE_CHECKING:
|
|
from text_generation_server.models.model import Model
|
|
|
|
|
|
@dataclass
class ModuleMap:
    """Pair a module name with the weights associated with that module.

    Both fields are positional, in declaration order: ``module_name`` first,
    then ``module_weights``.
    """

    # Name of the module that the weights below belong to.
    module_name: str
    # Weights keyed by string; each value is a ``(tensor, str)`` pair.
    # NOTE(review): the meaning of the ``str`` element is not visible in this
    # file (possibly the originating weight name) — confirm with callers.
    module_weights: Dict[str, Tuple[torch.Tensor, str]]
|
|
|
|
|
|
@dataclass
class AdapterConfig(ABC):
    """Abstract base class for adapter configurations.

    Subclasses must implement both abstract methods below. Because the class
    derives from ``ABC``, it cannot be instantiated until they do.

    Attributes:
        base_model_name_or_path: Identifier of the base model. This is
            presumably a hub name or local path of the model the adapter
            targets — confirm against the concrete config loaders.
    """

    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str, ...],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Map loaded adapter weights onto the base model's modules.

        Args:
            adapter_weights: Adapter weights keyed by an integer.
                NOTE(review): the key is presumably a layer index — confirm
                in implementations.
            weight_names: Names of the base-model weights to consider. This
                is a variable-length tuple (``Tuple[str, ...]``); any number
                of names may be passed.

        Returns:
            A ``(module_map, names)`` pair: a ``ModuleMap`` plus a set of
            weight names. NOTE(review): the set is presumably what is later
            passed to ``load_batched_adapter_weights`` as
            ``unused_weight_names`` — verify against subclasses.
        """

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        """Return the adapter weights for ``layer_type``, or ``None``.

        Args:
            model: The base model. It is annotated as a string because
                ``Model`` is imported only under ``TYPE_CHECKING``.
            module_map: Module mapping for this adapter.
            layer_type: The layer/module type to load weights for.
            unused_weight_names: A set of weight names. NOTE(review):
                implementations presumably update it as names are consumed
                — confirm.
            dynamic: Implementation-defined flag. TODO: document its meaning
                once confirmed in the subclasses.

        Returns:
            ``AdapterWeights`` for ``layer_type``, or ``None``. Presumably
            ``None`` means the adapter has no weights for that layer type.
        """
|