Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)
* feat: first draft load multiple lora
* feat: load weights within layer and refactor lora pass
* fix: refactor and reduce lora math
* feat: baseline impl single request multi lora support
* feat: prefer lorax implementation and port loading logic
* fix: prefer adapter_data and refactors
* feat: prefer lorax's custom punica kernels and add mlp loras
* fix: adjust batch for bgmv
* fix: adjust adapter_segments logic when in batch
* fix: refactor and move changes to v3 proto
* fix: pass model_id for all flash causal lms
* fix: pass model_id for all causal and seq2seq lms
* fix: add model_id to model test
* feat: add lora support to mistral and refactors
* feat: prefer model id in request
* fix: include rust code for adapter id
* feat: bump launcher and add new lora docs
* feat: support base model generation and refactors
* fix: rename doc to retry ci build
* feat: support if vlm models
* fix: add adapter_data param and avoid missing layers
* fix: add adapter_data param to phi and neox
* fix: update all models forwards to include adapter_data
* fix: add model_id to IdeficsCausalLM
* Update lora.md: fixed a typo
* Update lora.md: fixing spam image
* fix: add lora kernel to dockerfile, support running without kernels and refactors
* fix: avoid dockerfile conflict
* fix: refactors and adjust flash llama lora logic
* fix: skip llama test due to CI issue (temp)
* fix: skip llama test CI (temp) 2
* fix: revert skips and prefer updated ci token for tests
* fix: refactors and helpful comments
* fix: add noop in TensorParallelAdapterRowLinear too
* fix: refactor and move shard_lora_weights logic
* fix: exit early if no adapter_data

Co-authored-by: Derek <datavistics@gmail.com>
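Taken together, these commits let a server preload several LoRA adapters and route individual requests to one of them (or to the base model when no adapter is named). A minimal request sketch, assuming the LORA_ADAPTERS launch environment variable and the per-request adapter_id parameter described in the new lora.md; the adapter name, port, and endpoint below are placeholders, not values taken from this change:

# Sketch: query a TGI server assumed to have been started with, e.g.,
#   LORA_ADAPTERS=predibase/customer_support
import requests

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Hello, who are you?",
        "parameters": {
            "max_new_tokens": 40,
            # Select one of the preloaded LoRA adapters for this request;
            # omitting adapter_id generates with the base model instead.
            "adapter_id": "predibase/customer_support",
        },
    },
)
print(response.json()["generated_text"])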
47 lines · 1.2 KiB · Python
import torch
import os

from loguru import logger
from typing import Dict

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
        )
else:
    cuda_graphs = None

# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs

# This is overridden at model loading.
global MODEL_ID
MODEL_ID = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id


# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX: Dict[str, int] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index
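For reference, a minimal sketch of how these module-level globals might be exercised. The CUDA_GRAPHS environment variable and the two setters come from the code above; the import path and the call sites are assumptions about how the server wires things up at startup, not the actual loading code:

import os

# CUDA_GRAPHS is parsed once at import time from a comma-separated list of
# batch sizes, e.g. "1,2,4,8" becomes [8, 4, 2, 1] after the descending sort.
os.environ["CUDA_GRAPHS"] = "1,2,4,8"

# Module path assumed from the repository layout.
from text_generation_server.models import globals as tgi_globals

print(tgi_globals.CUDA_GRAPHS)  # [8, 4, 2, 1]

# At model-loading time the server is expected to record the base model id and
# the adapter-name -> index mapping so later LoRA passes can look them up.
tgi_globals.set_model_id("mistralai/Mistral-7B-v0.1")
tgi_globals.set_adapter_to_index({"predibase/customer_support": 0})

print(tgi_globals.MODEL_ID)          # "mistralai/Mistral-7B-v0.1"
print(tgi_globals.ADAPTER_TO_INDEX)  # {"predibase/customer_support": 0}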